Update.
[profiler-torch.git] / test-profiler.lua
1 #!/usr/bin/env luajit
2
3 require 'torch'
4 require 'nn'
5
6 require 'profiler'
7
8 local model = nn.Sequential()
9 model:add(nn.Linear(1000, 1000))
10 model:add(nn.ReLU())
11 model:add(nn.Linear(1000, 100))
12
13 profiler.decor(model)
14
15 local input = torch.Tensor(1000, 1000)
16 local target = torch.Tensor(input:size(1), 100)
17 local criterion = nn.MSECriterion()
18
19 local nbSamples = 0
20 local modelTime = 0
21 local dataTime = 0
22
23 for k = 1, 5 do
24    local t1 = sys.clock()
25    input:uniform(-1, 1)
26    target:uniform()
27
28    local t2 = sys.clock()
29
30    local output = model:forward(input)
31    local loss = criterion:forward(output, target)
32    local dloss = criterion:backward(output, target)
33    model:backward(input, dloss)
34
35    local t3 = sys.clock()
36
37    dataTime = dataTime + (t2 - t1)
38    modelTime = modelTime + (t3 - t2)
39
40    nbSamples = nbSamples + input:size(1)
41 end
42
43 profiler.print(model, nbSamples)
44
45 print('----------------------------------------------------------------------')
46 print(string.format('Total model time %.02fs', modelTime))
47 print(string.format('Total data time %.02fs', dataTime))