Fixed a stupid and embarrassing typo.
[profiler-torch.git] / test-profiler.lua
1 #!/usr/bin/env luajit
2
3 --[[
4
5    Written by Francois Fleuret (francois@fleuret.org)
6
7    This is free and unencumbered software released into the public
8    domain.
9
10    Anyone is free to copy, modify, publish, use, compile, sell, or
11    distribute this software, either in source code form or as a
12    compiled binary, for any purpose, commercial or non-commercial, and
13    by any means.
14
15    In jurisdictions that recognize copyright laws, the author or
16    authors of this software dedicate any and all copyright interest in
17    the software to the public domain. We make this dedication for the
18    benefit of the public at large and to the detriment of our heirs
19    and successors. We intend this dedication to be an overt act of
20    relinquishment in perpetuity of all present and future rights to
21    this software under copyright law.
22
23    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
24    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
25    MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
26    NONINFRINGEMENT.  IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
27    CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
28    CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
29    WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
30
31    For more information, please refer to <http://unlicense.org/>
32
33 ]]--
34
35 require 'torch'
36 require 'nn'
37
38 require 'profiler'
39
40 -- Create a model
41
42 local model = nn.Sequential()
43    :add(nn.Sequential()
44            :add(nn.Linear(1000, 1000))
45            :add(nn.ReLU())
46        )
47    :add(nn.Linear(1000, 100))
48
49 -- Decor it for profiling
50
51 profiler.decorate(model)
52 print()
53
54 torch.save('model.t7', model)
55
56 -- Create the data and criterion
57
58 local input = torch.Tensor(1000, 1000)
59 local target = torch.Tensor(input:size(1), 100)
60 local criterion = nn.MSECriterion()
61
62 local nbSamples = 0
63 local modelTime = 0
64 local dataTime = 0
65
66 -- Loop five times through the data forward and backward
67
68 for k = 1, 5 do
69    local t1 = sys.clock()
70
71    input:uniform(-1, 1)
72    target:uniform()
73
74    local t2 = sys.clock()
75
76    local output = model:forward(input)
77    local loss = criterion:forward(output, target)
78    local dloss = criterion:backward(output, target)
79    model:backward(input, dloss)
80
81    local t3 = sys.clock()
82
83    dataTime = dataTime + (t2 - t1)
84    modelTime = modelTime + (t3 - t2)
85
86    nbSamples = nbSamples + input:size(1)
87 end
88
89 -- Print the accumulated timings
90
91 profiler.print(model, nbSamples)
92 -- profiler.print(model)
93
94 print()
95 print(string.format('Total model time %.02fs', modelTime))
96 print(string.format('Total data time %.02fs', dataTime))