91b09157843a86365a03d54b909cc103d13a5a81
[profiler-torch.git] / profiler.lua
1
2 require 'torch'
3 require 'nn'
4 require 'sys'
5
6 profiler = {}
7
8 function profiler.decor(model, functionsToDecorate)
9
10    local functionsToDecorate = functionsToDecorate or
11       {
12          'updateOutput',
13          'backward'
14       }
15
16    for _, name in pairs(functionsToDecorate) do
17       model.orig = model.orig or {}
18       model.timings = 0
19
20       if model[name] and not model.orig[name] then
21          model.orig[name] = model[name]
22          model[name] = function(self, ...)
23             local startTime = sys.clock()
24             local result = { self.orig[name](self, unpack({...})) }
25             local endTime = sys.clock()
26             self.timings = self.timings + endTime - startTime
27             return unpack(result)
28          end
29       end
30
31    end
32
33    if torch.isTypeOf(model, nn.Container) then
34       for _, m in ipairs(model.modules) do
35          profiler.decor(m, functionsToDecorate)
36       end
37    end
38
39 end
40
41 function profiler.print(model)
42    print('----------------------------------------------------------------------')
43    print(model)
44    print(string.format('TIMING %.02fs', model.timings))
45    if torch.isTypeOf(model, nn.Container) then
46       model:applyToModules(profiler.print)
47    end
48 end
49
50 return profiler