From faf424f71ac259f3a7e676113cfe4892084b1c93 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Sun, 4 Dec 2016 15:49:41 +0100 Subject: [PATCH] Update. --- profiler.lua | 28 +++++++++++++++++++++------- test-profiler.lua | 32 +++++++++++++++++++++++++++----- 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/profiler.lua b/profiler.lua index 91b0915..49eaed4 100755 --- a/profiler.lua +++ b/profiler.lua @@ -14,12 +14,19 @@ function profiler.decor(model, functionsToDecorate) } for _, name in pairs(functionsToDecorate) do - model.orig = model.orig or {} model.timings = 0 - if model[name] and not model.orig[name] then - model.orig[name] = model[name] - model[name] = function(self, ...) + local functionTable = model + + if not rawget(functionTable, name) then + functionTable = getmetatable(model) + end + + if functionTable[name] and not (functionTable.orig and functionTable.orig[name]) then + print('Profiler decoring ' .. functionTable.__typename .. '.' .. name) + functionTable.orig = functionTable.orig or {} + functionTable.orig[name] = functionTable[name] + functionTable[name] = function(self, ...) local startTime = sys.clock() local result = { self.orig[name](self, unpack({...})) } local endTime = sys.clock() @@ -38,12 +45,19 @@ function profiler.decor(model, functionsToDecorate) end -function profiler.print(model) +function profiler.print(model, nbSamples) print('----------------------------------------------------------------------') print(model) - print(string.format('TIMING %.02fs', model.timings)) + if nbSamples then + print(string.format('acc_time %.02fs (%.1ems/sample)', model.timings, 1000 * model.timings / nbSamples)) + else + print(string.format('acc_time %.02fs', model.timings)) + end + if torch.isTypeOf(model, nn.Container) then - model:applyToModules(profiler.print) + for _, m in ipairs(model.modules) do + profiler.print(m, nbSamples) + end end end diff --git a/test-profiler.lua b/test-profiler.lua index b394a33..44bbee1 100755 --- a/test-profiler.lua +++ b/test-profiler.lua @@ -12,14 +12,36 @@ model:add(nn.Linear(1000, 100)) profiler.decor(model) -for k = 1, 10 do - local input = torch.Tensor(1000, 1000):uniform(-1, 1) - local target = torch.Tensor(input:size(1), 100):uniform() - local criterion = nn.MSECriterion() +local input = torch.Tensor(1000, 1000) +local target = torch.Tensor(input:size(1), 100) +local criterion = nn.MSECriterion() + +local nbSamples = 0 +local modelTime = 0 +local dataTime = 0 + +for k = 1, 5 do + local t1 = sys.clock() + input:uniform(-1, 1) + target:uniform() + + local t2 = sys.clock() + local output = model:forward(input) local loss = criterion:forward(output, target) local dloss = criterion:backward(output, target) model:backward(input, dloss) + + local t3 = sys.clock() + + dataTime = dataTime + (t2 - t1) + modelTime = modelTime + (t3 - t2) + + nbSamples = nbSamples + input:size(1) end -profiler.print(model) +profiler.print(model, nbSamples) + +print('----------------------------------------------------------------------') +print(string.format('Total model time %.02fs', modelTime)) +print(string.format('Total data time %.02fs', dataTime)) -- 2.20.1