Update.
[profiler-torch.git] / profiler.lua
1
2 --[[
3
4    Written by Francois Fleuret (francois@fleuret.org)
5
6    This is free and unencumbered software released into the public
7    domain.
8
9    Anyone is free to copy, modify, publish, use, compile, sell, or
10    distribute this software, either in source code form or as a
11    compiled binary, for any purpose, commercial or non-commercial, and
12    by any means.
13
14    In jurisdictions that recognize copyright laws, the author or
15    authors of this software dedicate any and all copyright interest in
16    the software to the public domain. We make this dedication for the
17    benefit of the public at large and to the detriment of our heirs
18    and successors. We intend this dedication to be an overt act of
19    relinquishment in perpetuity of all present and future rights to
20    this software under copyright law.
21
22    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
23    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
24    MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
25    NONINFRINGEMENT.  IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
26    CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
27    CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
28    WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
29
30    For more information, please refer to <http://unlicense.org/>
31
32 ]]--
33
34 require 'torch'
35 require 'nn'
36 require 'sys'
37
38 profiler = {}
39
40 profiler.color = true
41
42 profiler.colors = function(name)
43    if profiler.color then
44       return sys.COLORS[name]
45    else
46       return ''
47    end
48 end
49
50 function profiler.decorate(model, functionsToDecorate)
51
52    local functionsToDecorate = functionsToDecorate or
53       {
54          'updateOutput',
55          'backward'
56       }
57
58    for _, name in pairs(functionsToDecorate) do
59       model.accTime = {}
60
61       local nameOrig = name .. '__orig'
62
63       -- We decorate the class and not the object, otherwise we cannot
64       -- save models anymore.
65
66       if rawget(model, name) then
67          error('We decorate the classes, not the objects, and there is a `'
68                   .. name
69                   .. '\' function in '
70                   .. tostring(model))
71       end
72
73       local toDecorate = model
74
75       while not rawget(toDecorate, name) do
76          toDecorate = getmetatable(toDecorate)
77       end
78
79       if not toDecorate[nameOrig] then
80          -- print('Decorating ' .. toDecorate.__typename .. '.' .. name)
81          toDecorate[nameOrig] = toDecorate[name]
82          toDecorate[name] = function(self, ...)
83             local startTime = sys.clock()
84             local result = { self[nameOrig](self, unpack({...})) }
85             local endTime = sys.clock()
86             self.accTime[name] = (self.accTime[name] or 0) + endTime - startTime
87             return unpack(result)
88          end
89       end
90
91    end
92
93    if torch.isTypeOf(model, nn.Container) then
94       for _, m in ipairs(model.modules) do
95          profiler.decorate(m, functionsToDecorate)
96       end
97    end
98
99 end
100
101 function profiler.timing(l, t, nbSamples, totalTime)
102    local s
103
104    s = string.format('%s %.02fs %s[%.02f%%]',
105                      l, t,
106                      profiler.colors('blue'),
107                      100 * t / totalTime
108    )
109
110    if nbSamples then
111       s = s .. string.format(profiler.colors('green') .. ' (%.01fmus/sample)', 1e6 * t / nbSamples)
112    end
113
114    s = s .. profiler.colors('black')
115
116    return s
117 end
118
119 function profiler.print(model, nbSamples, totalTime, indent)
120    local indent = indent or ''
121    local hint
122
123    if not model.accTime then
124       error('The model does not seem decorated for profiling.')
125    end
126
127    local localTotal = 0
128    for _, t in pairs(model.accTime) do
129       localTotal = localTotal + t
130    end
131
132    totalTime = totalTime or localTotal
133
134    if torch.isTypeOf(model, nn.Container) then
135       hint = ' '
136    else
137       if profiler.color then
138          hint = ' '
139       else
140          hint = '*'
141       end
142       hint = hint .. profiler.colors('red')
143    end
144
145    print(profiler.timing(indent .. hint .. ' ' .. model.__typename,
146                          localTotal, nbSamples, totalTime))
147
148    for l, t in pairs(model.accTime) do
149       print(profiler.timing(indent .. '  :' .. l, t, nbSamples, totalTime))
150    end
151
152    print()
153
154    if torch.isTypeOf(model, nn.Container) then
155       for _, m in ipairs(model.modules) do
156          profiler.print(m, nbSamples, totalTime, indent .. '  ')
157       end
158    end
159 end
160
161 return profiler