--[[
Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/
Written by Francois Fleuret

This file is free software: you can redistribute it and/or modify it
under the terms of the GNU General Public License version 3 as
published by the Free Software Foundation.

It is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or
FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
for more details.

You should have received a copy of the GNU General Public License
along with this file. If not, see <http://www.gnu.org/licenses/>.
]]--

require 'torch'
require 'nn'
require 'sys'

-- Lua 5.2+ only provides table.unpack; Lua 5.1 / LuaJIT provide unpack.
local unpack = unpack or table.unpack

-- Capture a vararg list together with its exact length so that nil
-- values (including trailing ones) survive unpack(t, 1, t.n).
local function pack(...)
   return { n = select('#', ...), ... }
end

-- Kept global (as well as returned) for callers that rely on the global.
profiler = {}

-- Set to false to disable ANSI color escapes in the report.
profiler.color = true

-- Return the escape sequence sys.COLORS[name], or '' when coloring is
-- disabled or the color name is unknown.
profiler.colors = function(name)
   if profiler.color then
      return sys.COLORS[name] or ''
   else
      return ''
   end
end

-- Instrument `model` for profiling.
--
-- For every method name in `functionsToDecorate` (default: updateOutput
-- and backward), the class in `model`'s metatable chain that defines that
-- method gets its implementation wrapped. The wrapper accumulates the
-- elapsed sys.clock() time of each call into self.accTime[name]. When
-- `model` is an nn.Container, its sub-modules are decorated recursively.
--
-- Calling this resets model.accTime to an empty table.
--
-- Raises an error if the method is stored on the object itself (only
-- classes are decorated), or if no class in the chain defines it.
--
-- NOTE: if a subclass override explicitly calls a decorated parent
-- implementation on the same object, both calls are accumulated.
function profiler.decorate(model, functionsToDecorate)
   functionsToDecorate = functionsToDecorate or { 'updateOutput', 'backward' }

   model.accTime = {}

   for _, name in pairs(functionsToDecorate) do
      -- We decorate the class and not the object, otherwise we cannot
      -- save models anymore.
      if rawget(model, name) then
         error('We decorate the classes, not the objects, and there is a `' .. name .. '\' function in ' .. tostring(model))
      end

      -- Walk up the class chain to the class that actually defines `name`.
      local toDecorate = getmetatable(model)
      while toDecorate and not rawget(toDecorate, name) do
         toDecorate = getmetatable(toDecorate)
      end
      if not toDecorate then
         error('Cannot find a `' .. name .. '\' function to decorate in ' .. tostring(model))
      end

      local nameOrig = name .. '__orig'
      -- rawget: a decorated parent class must not hide an undecorated
      -- override defined in this class.
      if not rawget(toDecorate, nameOrig) then
         -- print('Decorating ' .. toDecorate.__typename .. '.' .. name)
         -- Capture the original as an upvalue: looking it up dynamically on
         -- `self` would find the subclass's original when a subclass calls
         -- its parent implementation, causing infinite recursion.
         local orig = toDecorate[name]
         toDecorate[nameOrig] = orig
         toDecorate[name] = function(self, ...)
            local accTime = self.accTime
            -- Instances of a decorated class that were never passed to
            -- profiler.decorate have no accumulator: run them untimed.
            if not accTime then
               return orig(self, ...)
            end
            local startTime = sys.clock()
            local result = pack(orig(self, ...))
            local endTime = sys.clock()
            accTime[name] = (accTime[name] or 0) + endTime - startTime
            return unpack(result, 1, result.n)
         end
      end
   end

   if torch.isTypeOf(model, nn.Container) then
      for _, m in ipairs(model.modules) do
         profiler.decorate(m, functionsToDecorate)
      end
   end
end

-- Format one report line: label `l`, time `t` (printed in seconds), its
-- percentage of `totalTime` and, when `nbSamples` is given, the time per
-- sample in microseconds. Zero totals yield 0% / no per-sample figure
-- instead of nan/inf.
function profiler.timingString(l, t, nbSamples, totalTime)
   local percent = 0
   if totalTime and totalTime > 0 then
      percent = 100 * t / totalTime
   end

   local s = string.format('%s %.02fs %s[%.02f%%]',
                           l, t, profiler.colors('blue'), percent)

   if nbSamples and nbSamples > 0 then
      s = s .. string.format('%s (%.01fmus/sample)',
                             profiler.colors('green'), 1e6 * t / nbSamples)
   end

   -- Reset to the terminal's default color rather than forcing black.
   return s .. profiler.colors('none')
end

-- Print the timings accumulated on a decorated `model`, recursing into the
-- modules of nn.Container instances with a deeper `indent`. Percentages
-- are relative to `totalTime`, which defaults to the total of the
-- top-level model. `nbSamples` is optional; see profiler.timingString.
function profiler.print(model, nbSamples, totalTime, indent)
   indent = indent or ''

   if not model.accTime then
      error('The model does not seem decorated for profiling.')
   end

   -- Collect method names (sorted: pairs order is unspecified) and the total.
   local names = {}
   local localTotal = 0
   for l, t in pairs(model.accTime) do
      names[#names + 1] = l
      localTotal = localTotal + t
   end
   table.sort(names)

   totalTime = totalTime or localTotal

   local isContainer = torch.isTypeOf(model, nn.Container)

   -- Leaf modules are highlighted (red, or a '*' marker without colors).
   local hint
   if isContainer then
      hint = ' '
   elseif profiler.color then
      hint = ' ' .. profiler.colors('red')
   else
      hint = '* '
   end

   print(profiler.timingString(indent .. hint .. model.__typename, localTotal, nbSamples, totalTime))

   for _, l in ipairs(names) do
      print(profiler.timingString(indent .. ' :' .. l, model.accTime[l], nbSamples, totalTime))
   end

   print()

   if isContainer then
      for _, m in ipairs(model.modules) do
         profiler.print(m, nbSamples, totalTime, indent .. ' ')
      end
   end
end

return profiler