Update.
[profiler-torch.git] / profiler.lua
1
2 require 'torch'
3 require 'nn'
4 require 'sys'
5
6 profiler = {}
7
8 function profiler.decor(model, functionsToDecorate)
9
10    local functionsToDecorate = functionsToDecorate or
11       {
12          'updateOutput',
13          'backward'
14       }
15
16    for _, name in pairs(functionsToDecorate) do
17       model.timings = 0
18
19       local functionTable = model
20
21       if not rawget(functionTable, name) then
22          functionTable = getmetatable(model)
23       end
24
25       if functionTable[name] and not (functionTable.orig and functionTable.orig[name]) then
26          print('Profiler decoring ' .. functionTable.__typename .. '.' .. name)
27          functionTable.orig = functionTable.orig or {}
28          functionTable.orig[name] = functionTable[name]
29          functionTable[name] = function(self, ...)
30             local startTime = sys.clock()
31             local result = { self.orig[name](self, unpack({...})) }
32             local endTime = sys.clock()
33             self.timings = self.timings + endTime - startTime
34             return unpack(result)
35          end
36       end
37
38    end
39
40    if torch.isTypeOf(model, nn.Container) then
41       for _, m in ipairs(model.modules) do
42          profiler.decor(m, functionsToDecorate)
43       end
44    end
45
46 end
47
48 function profiler.print(model, nbSamples)
49    print('----------------------------------------------------------------------')
50    print(model)
51    if nbSamples then
52       print(string.format('acc_time %.02fs (%.1ems/sample)', model.timings, 1000 * model.timings / nbSamples))
53    else
54       print(string.format('acc_time %.02fs', model.timings))
55    end
56
57    if torch.isTypeOf(model, nn.Container) then
58       for _, m in ipairs(model.modules) do
59          profiler.print(m, nbSamples)
60       end
61    end
62 end
63
64 return profiler