Removed the signature.
[profiler-torch.git] / test-profiler.lua
1 #!/usr/bin/env luajit
2
3 --[[
4
5    Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/
6    Written by Francois Fleuret <francois.fleuret@idiap.ch>
7
8    This file is free software: you can redistribute it and/or modify
9    it under the terms of the GNU General Public License version 3 as
10    published by the Free Software Foundation.
11
12    It is distributed in the hope that it will be useful, but WITHOUT
13    ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
14    or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public
15    License for more details.
16
17    You should have received a copy of the GNU General Public License
18    along with this file.  If not, see <http://www.gnu.org/licenses/>.
19
20 ]]--
21
22 require 'torch'
23 require 'nn'
24
25 require 'profiler'
26
27 -- Create a model
28
29 local w, h, fs = 50, 50, 3
30 local nhu =  (w - fs + 1) * (h - fs + 1)
31
32 local model = nn.Sequential()
33    :add(nn.Sequential()
34            :add(nn.SpatialConvolution(1, 1, fs, fs))
35            :add(nn.Reshape(nhu))
36            :add(nn.Linear(nhu, 1000))
37            :add(nn.ReLU())
38        )
39    :add(nn.Linear(1000, 100))
40
41 -- Decorate it for profiling
42
43 profiler.decorate(model)
44
45 -- Create the data and criterion
46
47 local input = torch.Tensor(1000, 1, h, w)
48 local target = torch.Tensor(input:size(1), 100)
49 local criterion = nn.MSECriterion()
50
51 local nbSamples = 0
52 local modelTime = 0
53 local dataTime = 0
54
55 -- Loop five times through the data forward and backward
56
57 for k = 1, 5 do
58    local t1 = sys.clock()
59
60    input:uniform(-1, 1)
61    target:uniform()
62
63    local t2 = sys.clock()
64
65    local output = model:forward(input)
66    local loss = criterion:forward(output, target)
67    local dloss = criterion:backward(output, target)
68    model:backward(input, dloss)
69
70    local t3 = sys.clock()
71
72    dataTime = dataTime + (t2 - t1)
73    modelTime = modelTime + (t3 - t2)
74
75    nbSamples = nbSamples + input:size(1)
76 end
77
78 -- Print the accumulated timings
79
80 print()
81 -- profiler.color = false
82 profiler.print(model, nbSamples)
83 -- profiler.print(model)
84
85 print(string.format('Total model time %.02fs', modelTime))
86 print(string.format('Total data time %.02fs', dataTime))