X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=profiler-torch.git;a=blobdiff_plain;f=profiler.lua;h=49eaed426dfbddd8c0e2950ecccdb8b990e8160e;hp=91b09157843a86365a03d54b909cc103d13a5a81;hb=faf424f71ac259f3a7e676113cfe4892084b1c93;hpb=21c3bd2eb990e3fa58aa36b0e8fcd8901de5569c diff --git a/profiler.lua b/profiler.lua index 91b0915..49eaed4 100755 --- a/profiler.lua +++ b/profiler.lua @@ -14,12 +14,19 @@ function profiler.decor(model, functionsToDecorate) } for _, name in pairs(functionsToDecorate) do - model.orig = model.orig or {} model.timings = 0 - if model[name] and not model.orig[name] then - model.orig[name] = model[name] - model[name] = function(self, ...) + local functionTable = model + + if not rawget(functionTable, name) then + functionTable = getmetatable(model) + end + + if functionTable[name] and not (functionTable.orig and functionTable.orig[name]) then + print('Profiler decoring ' .. functionTable.__typename .. '.' .. name) + functionTable.orig = functionTable.orig or {} + functionTable.orig[name] = functionTable[name] + functionTable[name] = function(self, ...) local startTime = sys.clock() local result = { self.orig[name](self, unpack({...})) } local endTime = sys.clock() @@ -38,12 +45,19 @@ function profiler.decor(model, functionsToDecorate) end -function profiler.print(model) +function profiler.print(model, nbSamples) print('----------------------------------------------------------------------') print(model) - print(string.format('TIMING %.02fs', model.timings)) + if nbSamples then + print(string.format('acc_time %.02fs (%.1ems/sample)', model.timings, 1000 * model.timings / nbSamples)) + else + print(string.format('acc_time %.02fs', model.timings)) + end + if torch.isTypeOf(model, nn.Container) then - model:applyToModules(profiler.print) + for _, m in ipairs(model.modules) do + profiler.print(m, nbSamples) + end end end