X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=profiler-torch.git;a=blobdiff_plain;f=test-profiler.lua;h=44bbee1656ee13d1b3634ef52c71c5afdf22fbf6;hp=b394a332898ac52f72e80201382e50ce95922ef6;hb=faf424f71ac259f3a7e676113cfe4892084b1c93;hpb=21c3bd2eb990e3fa58aa36b0e8fcd8901de5569c diff --git a/test-profiler.lua b/test-profiler.lua index b394a33..44bbee1 100755 --- a/test-profiler.lua +++ b/test-profiler.lua @@ -12,14 +12,36 @@ model:add(nn.Linear(1000, 100)) profiler.decor(model) -for k = 1, 10 do - local input = torch.Tensor(1000, 1000):uniform(-1, 1) - local target = torch.Tensor(input:size(1), 100):uniform() - local criterion = nn.MSECriterion() +local input = torch.Tensor(1000, 1000) +local target = torch.Tensor(input:size(1), 100) +local criterion = nn.MSECriterion() + +local nbSamples = 0 +local modelTime = 0 +local dataTime = 0 + +for k = 1, 5 do + local t1 = sys.clock() + input:uniform(-1, 1) + target:uniform() + + local t2 = sys.clock() + local output = model:forward(input) local loss = criterion:forward(output, target) local dloss = criterion:backward(output, target) model:backward(input, dloss) + + local t3 = sys.clock() + + dataTime = dataTime + (t2 - t1) + modelTime = modelTime + (t3 - t2) + + nbSamples = nbSamples + input:size(1) end -profiler.print(model) +profiler.print(model, nbSamples) + +print('----------------------------------------------------------------------') +print(string.format('Total model time %.02fs', modelTime)) +print(string.format('Total data time %.02fs', dataTime))