--- Lightweight per-module timing profiler for torch `nn` models.
-- `profiler.decor` wraps selected methods of a module (and, recursively, of
-- every child of an `nn.Container`) so that the time spent in each call is
-- accumulated in `module.timings`, measured with `sys.clock`.
-- `profiler.print` dumps every module together with its accumulated timing.
require 'torch'
require 'nn'
require 'sys'

-- `unpack` is a global in Lua 5.1/LuaJIT but only `table.unpack` in 5.2+.
local unpack = unpack or table.unpack

-- NOTE(review): `profiler` is deliberately left global, as in the original
-- file, so callers that rely on `require 'profiler'` defining it keep working.
profiler = {}

--- Capture every value in `...` together with its exact count.
-- Unlike `{...}` + `#`, this keeps nil values (including trailing ones).
local function pack(...)
  return { n = select('#', ...), ... }
end

--- Instrument `model` (and its descendants) so calls are timed.
-- Resets `model.timings` to 0 on every call. A method already recorded in
-- `model.orig` is not wrapped a second time.
-- @param model an nn module
-- @param functionsToDecorate optional table of method names to wrap
--        (default `{ 'updateOutput', 'backward' }`)
function profiler.decor(model, functionsToDecorate)
  functionsToDecorate = functionsToDecorate or { 'updateOutput', 'backward' }
  model.orig = model.orig or {}
  model.timings = 0
  for _, name in pairs(functionsToDecorate) do
    if model[name] and not model.orig[name] then
      model.orig[name] = model[name]
      model[name] = function(self, ...)
        local startTime = sys.clock()
        -- Forward `...` as-is and keep an explicit result count so nil
        -- arguments/results are never dropped by `#` on a table with holes.
        local result = pack(self.orig[name](self, ...))
        local endTime = sys.clock()
        self.timings = self.timings + endTime - startTime
        return unpack(result, 1, result.n)
      end
    end
  end
  if torch.isTypeOf(model, nn.Container) then
    for _, m in ipairs(model.modules) do
      profiler.decor(m, functionsToDecorate)
    end
  end
end

--- Print `model` and its accumulated timing, then recurse into children.
-- Children are visited via `applyToModules` when `model` is an `nn.Container`.
-- @param model an nn module, normally one previously passed to `profiler.decor`
function profiler.print(model)
  print('----------------------------------------------------------------------')
  print(model)
  -- Modules that were never decorated have no `timings`; report 0 instead of
  -- raising inside string.format.
  print(string.format('TIMING %.02fs', model.timings or 0))
  if torch.isTypeOf(model, nn.Container) then
    model:applyToModules(profiler.print)
  end
end

return profiler