#!/usr/bin/env luajit

--[[

   dyncnn is a deep-learning algorithm for the prediction of
   interacting object dynamics

   Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/
   Written by Francois Fleuret

   This file is part of dyncnn.

   dyncnn is free software: you can redistribute it and/or modify it
   under the terms of the GNU General Public License version 3 as
   published by the Free Software Foundation.

   dyncnn is distributed in the hope that it will be useful, but
   WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
   General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with dyncnn.  If not, see http://www.gnu.org/licenses/.

]]--

require 'torch'
require 'nn'
require 'optim'
require 'image'
require 'pl'

----------------------------------------------------------------------
-- Command line options. Each option must sit on its own line for
-- lapp to parse the specification.

local opt = lapp[[
   --seed                (default 1)               random seed

   --learningStateFile   (default '')
   --dataDir             (default './data/10p-mg/')
   --resultDir           (default '/tmp/dyncnn')

   --learningRate        (default -1)
   --momentum            (default -1)
   --nbEpochs            (default -1)              nb of epochs for the heavy setting

   --heavy                                         use the heavy configuration

   --nbChannels          (default -1)              nb of channels in the internal layers
   --resultFreq          (default 100)

   --noLog                                         supress logging

   --exampleInternals    (default -1)
]]

----------------------------------------------------------------------
-- Rebuild the command line (script name included, arg[0]) so that it
-- can be written in the log.

commandLine = ''
for i = 0, #arg do
   commandLine = commandLine .. ' \'' .. arg[i] .. '\''
end

----------------------------------------------------------------------
-- Logging helpers

colors = sys.COLORS

-- Holds the log file name and handle once they are known
global = {}

-- Writes the string s to the log file (if one is open) and to the
-- standard output, prefixed with the color code c (colors.black if c
-- is not provided). Both streams are flushed.
function logString(s, c)
   if global.logFile then
      global.logFile:write(s)
      global.logFile:flush()
   end
   c = c or colors.black
   io.write(c .. s)
   io.flush()
end

-- Runs the shell command c through sys.execute and logs both the
-- command and what sys.execute returned.
function logCommand(c)
   logString('[' .. c .. '] -> [' .. sys.execute(c) .. ']\n', colors.blue)
end

logString('commandline: ' .. commandLine .. '\n', colors.blue)

logCommand('mkdir -v -p ' .. opt.resultDir)

if not opt.noLog then
   global.logName = opt.resultDir .. '/log'
   local logFile, err = io.open(global.logName, 'a')
   if logFile then
      global.logFile = logFile
   else
      -- Keep going without a log file, but do not fail silently
      logString('Cannot open log file ' .. global.logName
                   .. ' (' .. tostring(err) .. ').\n', colors.red)
   end
end

----------------------------------------------------------------------

-- Call sites from which logOnce has already logged something
alreadyLoggedString = {}

-- Logs the string s in red, but only the first time it is called from
-- a given source location.
function logOnce(s)
   -- Level 2 is the caller of logOnce. Level 1 would be logOnce
   -- itself, which would make every call site share the same key.
   local info = debug.getinfo(2, 'Sl') or debug.getinfo(1, 'Sl')
   local l = info.currentline
   local key = info.short_src .. ':' .. l
   if not alreadyLoggedString[key] then
      logString('@line ' .. l .. ' ' .. s, colors.red)
      alreadyLoggedString[key] = s
   end
end

----------------------------------------------------------------------
-- Execution environment

-- os.getenv returns a string, convert it so that we handle a number
nbThreads = tonumber(os.getenv('TORCH_NB_THREADS')) or 1

useGPU = os.getenv('TORCH_USE_GPU') == 'yes'

for _, c in pairs({ 'date', 'uname -a', 'git log -1 --format=%H' }) do
   logCommand(c)
end

logString('useGPU is \'' .. tostring(useGPU) .. '\'.\n')
logString('nbThreads is \'' .. nbThreads .. '\'.\n')

----------------------------------------------------------------------

torch.setnumthreads(nbThreads)
torch.setdefaulttensortype('torch.FloatTensor')
torch.manualSeed(opt.seed)

mynn = {}

-- By default, mynn returns the entries from nn
setmetatable(mynn, { __index = nn })

-- These are the tensors that can be kept on the CPU
mynn.SlowTensor = torch.Tensor
-- These are the tensors that should be moved to the GPU
mynn.FastTensor = torch.Tensor

----------------------------------------------------------------------

if useGPU then
   require 'cutorch'
   require 'cunn'
   require 'cudnn'
   mynn.FastTensor = torch.CudaTensor
   mynn.SpatialConvolution = cudnn.SpatialConvolution
end

----------------------------------------------------------------------
-- Configuration: defaults first, then the light / heavy presets, then
-- the values given explicitly on the command line.

config = {}
config.learningRate = 0.1
config.momentum = 0
config.batchSize = 128
config.filterSize = 5

if opt.heavy then

   logString('Using the heavy configuration.\n')
   config.nbChannels = 16
   config.nbBlocks = 4
   config.nbEpochs = 250
   config.nbEpochsInit = 100
   config.nbTrainSamples = 32768
   config.nbValidationSamples = 1024
   config.nbTestSamples = 1024

else

   logString('Using the light configuration.\n')
   config.nbChannels = 2
   config.nbBlocks = 2
   config.nbEpochs = 6
   config.nbEpochsInit = 3
   config.nbTrainSamples = 1024
   config.nbValidationSamples = 1024
   config.nbTestSamples = 1024

end

-- A negative value means "not set on the command line"
if opt.nbEpochs > 0 then
   config.nbEpochs = opt.nbEpochs
end

if opt.nbChannels > 0 then
   config.nbChannels = opt.nbChannels
end

if opt.learningRate > 0 then
   config.learningRate = opt.learningRate
end

if opt.momentum >= 0 then
   config.momentum = opt.momentum
end

----------------------------------------------------------------------

-- Visits every module of model with model:apply and, for each field
-- of a module whose torch.type is tensorType, accumulates its number
-- of elements. Returns a table mapping field name to the total count.
function tensorCensus(tensorType, model)

   local nb = {}

   local function countThings(m)
      for k, i in pairs(m) do
         if torch.type(i) == tensorType then
            nb[k] = (nb[k] or 0) + i:nElement()
         end
      end
   end

   model:apply(countThings)

   return nb

end

----------------------------------------------------------------------

-- Loads the image filename, inverts it (x -> 1 - x), takes the max
-- over its first dimension and copies the result into tensor.
local function loadInvertedMax(filename, tensor)
   local tmp = image.load(filename)
   tmp:mul(-1.0):add(1.0)
   tensor:copy(torch.max(tmp, 1))
end

-- Returns the data set made of the nb samples starting at sample
-- number first (1-based), labelled with name. The samples are read
-- from the png files under opt.dataDir the first time, and cached in
-- a persistent file in opt.dataDir which is used afterward.
--
-- The returned table has the fields name, nbSamples, width, height,
-- input (nbSamples x 2 x height x width), target (nbSamples x 1 x
-- height x width) and persistentFileName.
function loadData(first, nb, name)
   logString('Loading data `' .. name .. '\'.\n')

   local persistentFileName = string.format('%s/persistent_%d_%d.dat',
                                            opt.dataDir,
                                            first,
                                            nb)

   -- This is at what framerate we work. It is greater than 1 so that
   -- we can keep on disk sequences at a higher frame rate for videos
   -- and explaining materials

   local frameRate = 4

   if not path.exists(persistentFileName) then
      logString(string.format('No persistent data structure, creating it (%d samples).\n', nb))
      local data = {}
      data.name = name
      data.nbSamples = nb
      data.width = 64
      data.height = 64
      data.input = mynn.SlowTensor(data.nbSamples, 2, data.height, data.width)
      data.target = mynn.SlowTensor(data.nbSamples, 1, data.height, data.width)

      for i = 1, data.nbSamples do
         -- Samples are numbered from 0 on disk, 1000 per directory
         local n = i-1 + first-1
         local prefix = string.format('%s/%03d/dyn_%06d',
                                      opt.dataDir,
                                      math.floor(n/1000), n)

         loadInvertedMax(prefix .. '_world_000.png', data.input[i][1])
         loadInvertedMax(prefix .. '_grab.png', data.input[i][2])
         loadInvertedMax(string.format('%s_world_%03d.png', prefix, frameRate),
                         data.target[i][1])
      end

      data.persistentFileName = persistentFileName

      torch.save(persistentFileName, data)
   end

   logCommand('sha256sum -b ' .. persistentFileName)

   -- Always go through the persistent file, even if we just created it
   return torch.load(persistentFileName)
end

----------------------------------------------------------------------

-- This function gets as input a list of tensors of arbitrary
-- dimensions each, but whose two last dimensions stand for height x
-- width. It creates an image tensor (3 x height x width) with each
-- argument tensor unfolded per row, every 2d slice being framed in
-- black.
--
-- Values are divided by the root mean square of their tensor. If
-- signed is true, positive values go from white to red and negative
-- ones from white to blue, otherwise values go from white (<= 0) to
-- black (>= 1).
function imageFromTensors(bt, signed)

   local gap = 1
   local tgap = -1
   local width = 0
   local height = gap

   -- First pass to compute the size of the resulting image. We use
   -- ipairs since the order of the rows matters.
   for _, t in ipairs(bt) do
      local d = t:dim()
      local h, w = t:size(d - 1), t:size(d)
      local n = t:nElement() / (w * h)
      width = math.max(width, gap + n * (gap + w))
      height = height + gap + tgap + gap + h
   end

   local e = torch.Tensor(3, height, width):fill(1.0)
   local y0 = 1 + gap

   for _, t in ipairs(bt) do
      local d = t:dim()
      local h, w = t:size(d - 1), t:size(d)
      local n = t:nElement() / (w * h)
      local z = t:norm() / math.sqrt(t:nElement())

      -- An all-zero tensor would lead to 0/0 and NaN pixels, draw it
      -- as zeros instead
      if z == 0 then z = 1 end

      -- The row is centered horizontally
      local x0 = 1 + gap + math.floor( (width - n * (w + gap)) /2 )
      local u = torch.Tensor(t:size()):copy(t):resize(n, h, w)

      for m = 1, n do

         -- The black frame around the slice
         for c = 1, 3 do
            for y = 0, h+1 do
               e[c][y0 + y - 1][x0     - 1] = 0.0
               e[c][y0 + y - 1][x0 + w    ] = 0.0
            end
            for x = 0, w+1 do
               e[c][y0     - 1][x0 + x - 1] = 0.0
               e[c][y0 + h    ][x0 + x - 1] = 0.0
            end
         end

         for y = 1, h do
            for x = 1, w do
               local v = u[m][y][x] / z
               local r, g, b
               if signed then
                  if v < -1 then
                     r, g, b = 0.0, 0.0, 1.0
                  elseif v > 1 then
                     r, g, b = 1.0, 0.0, 0.0
                  elseif v >= 0 then
                     r, g, b = 1.0, 1.0 - v, 1.0 - v
                  else
                     r, g, b = 1.0 + v, 1.0 + v, 1.0
                  end
               else
                  if v <= 0 then
                     r, g, b = 1.0, 1.0, 1.0
                  elseif v > 1 then
                     r, g, b = 0.0, 0.0, 0.0
                  else
                     r, g, b = 1.0 - v, 1.0 - v, 1.0 - v
                  end
               end
               e[1][y0 + y - 1][x0 + x - 1] = r
               e[2][y0 + y - 1][x0 + x - 1] = g
               e[3][y0 + y - 1][x0 + x - 1] = b
            end
         end

         x0 = x0 + w + gap
      end

      y0 = y0 + h + gap + tgap + gap
   end

   return e
end

-- Appends to collection.outputs (and counts in collection.nb) the
-- output tensors of the modules of model. nn.Sequential containers
-- are visited recursively; any other module is collected if which is
-- nil or if which[its type] is set, and if its output is a
-- torch.FloatTensor or a torch.CudaTensor.
function collectAllOutputs(model, collection, which)
   if torch.type(model) == 'nn.Sequential' then
      for i = 1, #model.modules do
         collectAllOutputs(model.modules[i], collection, which)
      end
   elseif not which or which[torch.type(model)] then
      local t = torch.type(model.output)
      if t == 'torch.FloatTensor' or t == 'torch.CudaTensor' then
         collection.nb = collection.nb + 1
         collection.outputs[collection.nb] = model.output
      end
   end
end

-- Runs the model on the sample number n of data and saves as a png in
-- opt.resultDir an image showing the input, the outputs of the
-- nn.ReLU modules, and the final output.
function saveInternalsImage(model, data, n)
   -- Explicitly copy to keep input as a mynn.FastTensor
   local input = mynn.FastTensor(1, 2, data.height, data.width)
   input:copy(data.input:narrow(1, n, 1))

   -- We only need the forward pass to populate the module outputs
   model:forward(input)

   local collection = {}
   collection.outputs = {}
   collection.nb = 1
   collection.outputs[collection.nb] = input

   local which = {}
   which['nn.ReLU'] = true
   collectAllOutputs(model, collection, which)

   -- Add the final output if it was not collected already
   if collection.outputs[collection.nb] ~= model.output then
      collection.nb = collection.nb + 1
      collection.outputs[collection.nb] = model.output
   end

   local fileName = string.format('%s/internals_%s_%06d.png',
                                  opt.resultDir,
                                  data.name, n)

   logString('Saving ' .. fileName .. '\n')
   image.save(fileName, imageFromTensors(collection.outputs))
end

----------------------------------------------------------------------

-- Saves in opt.resultDir one png per sample for the first nbMax
-- samples of data (50 if nbMax is not provided). Each image shows,
-- from top to bottom, the second input channel, the first input
-- channel, the target and the output of the model. The file names
-- start with prefix ('result' if not provided).
--
-- If highlight is true, the target and the output are attenuated
-- where they do not differ from the first input channel.
--
-- The L2 loss on each sample is logged on a LOSS_ON_SAMPLE line.
function saveResultImage(model, data, prefix, nbMax, highlight)
   local l2criterion = nn.MSECriterion()

   if useGPU then
      logString('Moving the criterion to the GPU.\n')
      l2criterion:cuda()
   end

   prefix = prefix or 'result'
   nbMax = nbMax or 50

   -- Four images on top of each other, separated and framed by one
   -- pixel
   local result = torch.Tensor(data.height * 4 + 5, data.width + 2)
   local input = mynn.FastTensor(1, 2, data.height, data.width)
   local target = mynn.FastTensor(1, 1, data.height, data.width)

   local nb = math.min(nbMax, data.nbSamples)

   model:evaluate()

   logString(string.format('Write %d result images `%s\' for set `%s\' in %s.\n',
                           nb, prefix, data.name,
                           opt.resultDir))

   for n = 1, nb do

      -- Explicitly copy to keep input as a mynn.FastTensor
      input:copy(data.input:narrow(1, n, 1))
      target:copy(data.target:narrow(1, n, 1))

      local output = model:forward(input)
      local loss = l2criterion:forward(output, target)

      result:fill(1.0)

      if highlight then
         for i = 1, data.height do
            for j = 1, data.width do
               local v = data.input[n][1][i][j]
               result[1 + i + 0 * (data.height + 1)][1 + j] = data.input[n][2][i][j]
               result[1 + i + 1 * (data.height + 1)][1 + j] = v
               local a = data.target[n][1][i][j]
               local b = output[1][1][i][j]
               result[1 + i + 2 * (data.height + 1)][1 + j] =
                  a * math.min(1, 0.1 + 2.0 * math.abs(a - v))
               result[1 + i + 3 * (data.height + 1)][1 + j] =
                  b * math.min(1, 0.1 + 2.0 * math.abs(b - v))
            end
         end
      else
         for i = 1, data.height do
            for j = 1, data.width do
               result[1 + i + 0 * (data.height + 1)][1 + j] = data.input[n][2][i][j]
               result[1 + i + 1 * (data.height + 1)][1 + j] = data.input[n][1][i][j]
               result[1 + i + 2 * (data.height + 1)][1 + j] = data.target[n][1][i][j]
               result[1 + i + 3 * (data.height + 1)][1 + j] = output[1][1][i][j]
            end
         end
      end

      -- Invert the values before saving (x -> 1 - x)
      result:mul(-1.0):add(1.0)

      local fileName = string.format('%s/%s_%s_%06d.png',
                                     opt.resultDir,
                                     prefix,
                                     data.name, n)

      logString(string.format('LOSS_ON_SAMPLE %f %s\n', loss, fileName))

      image.save(fileName, result)
   end
end

----------------------------------------------------------------------

-- Returns a nn.Sequential made of nbBlocks residual blocks. Each
-- block sums its input with the result of conv / batch norm / ReLU /
-- conv, and is followed by a batch norm and a ReLU. The convolutions
-- are filterSize x filterSize with stride 1, padded with
-- (filterSize - 1) / 2 on each side, and keep nbChannels channels.
function createTower(filterSize, nbChannels, nbBlocks)

   local pad = (filterSize - 1) / 2
   local tower = mynn.Sequential()

   for b = 1, nbBlocks do
      local block = mynn.Sequential()

      block:add(mynn.SpatialConvolution(nbChannels,
                                        nbChannels,
                                        filterSize, filterSize,
                                        1, 1,
                                        pad, pad))
      block:add(mynn.SpatialBatchNormalization(nbChannels))
      block:add(mynn.ReLU(true))

      block:add(mynn.SpatialConvolution(nbChannels,
                                        nbChannels,
                                        filterSize, filterSize,
                                        1, 1,
                                        pad, pad))

      local parallel = mynn.ConcatTable()
      parallel:add(block):add(mynn.Identity())

      tower:add(parallel):add(mynn.CAddTable(true))

      tower:add(mynn.SpatialBatchNormalization(nbChannels))
      tower:add(mynn.ReLU(true))
   end

   return tower

end

-- Returns the full model: a convolution from the 2 input channels to
-- nbChannels channels followed by a batch norm and a ReLU, two towers
-- built with createTower, and a final convolution to a single
-- channel.
function createModel(filterSize, nbChannels, nbBlocks)

   local pad = (filterSize - 1) / 2
   local model = mynn.Sequential()

   model:add(mynn.SpatialConvolution(2,
                                     nbChannels,
                                     filterSize, filterSize,
                                     1, 1,
                                     pad, pad))

   model:add(mynn.SpatialBatchNormalization(nbChannels))
   model:add(mynn.ReLU(true))

   local towerCode   = createTower(filterSize, nbChannels, nbBlocks)
   local towerDecode = createTower(filterSize, nbChannels, nbBlocks)

   model:add(towerCode)
   model:add(towerDecode)

   -- Decode to a single channel, which is the final image
   model:add(mynn.SpatialConvolution(nbChannels,
                                     1,
                                     filterSize, filterSize,
                                     1, 1,
                                     pad, pad))

   return model
end

----------------------------------------------------------------------

-- Copies nb samples of data into the rows 1..nb of batch.input and
-- batch.target. The sample indexes are first, first + 1, etc. and go
-- through permutation if it is provided.
function fillBatch(data, first, nb, batch, permutation)
   for k = 1, nb do
      local i
      if permutation then
         i = permutation[first + k - 1]
      else
         i = first + k - 1
      end
      batch.input[k] = data.input[i]
      batch.target[k] = data.target[i]
   end
end

-- Trains model with SGD and a L2 loss on trainData, up to the epoch
-- nbEpochs. It restarts after model.epoch, and from model.RNGState,
-- if the model has these fields.
--
-- learningRate is the SGD learning rate (config.learningRate if it is
-- not provided), the momentum is config.momentum and the batch size
-- config.batchSize.
--
-- After each epoch, the train and validation losses are logged on a
-- LOSS line and, if learningStateFile is provided, the model is saved
-- there with its epoch and RNG state. Every opt.resultFreq epochs, a
-- copy of the model and result images are saved in opt.resultDir.
function trainModel(model,
                    trainData, validationData, nbEpochs, learningRate,
                    learningStateFile)

   local l2criterion = nn.MSECriterion()
   local batchSize = config.batchSize

   if useGPU then
      logString('Moving the criterion to the GPU.\n')
      l2criterion:cuda()
   end

   local batch = {}

   batch.input = mynn.FastTensor(batchSize, 2, trainData.height, trainData.width)
   batch.target = mynn.FastTensor(batchSize, 1, trainData.height, trainData.width)

   local startingEpoch = 1

   if model.epoch then
      startingEpoch = model.epoch + 1
   end

   if model.RNGState then
      torch.setRNGState(model.RNGState)
   end

   logString('Starting training.\n')

   local parameters, gradParameters = model:getParameters()
   logString(string.format('model has %d parameters.\n', parameters:storage():size(1)))

   local averageTrainLoss, averageValidationLoss
   local trainTime, validationTime

   local sgdState = {
      learningRate = learningRate or config.learningRate,
      momentum = config.momentum,
      learningRateDecay = 0
   }

   for e = startingEpoch, nbEpochs do

      model:training()

      local permutation = torch.randperm(trainData.nbSamples)

      local accLoss = 0.0
      local nbBatches = 0
      local startTime = sys.clock()

      for b = 1, trainData.nbSamples, batchSize do

         -- The last batch is smaller if the number of samples is not
         -- a multiple of the batch size
         local nb = math.min(batchSize, trainData.nbSamples - b + 1)

         fillBatch(trainData, b, nb, batch, permutation)

         local input = batch.input:narrow(1, 1, nb)
         local target = batch.target:narrow(1, 1, nb)

         local opfunc = function(x)

            -- Surprisingly copy() needs this check
            if x ~= parameters then
               parameters:copy(x)
            end

            local output = model:forward(input)
            local loss = l2criterion:forward(output, target)

            local dLossdOutput = l2criterion:backward(output, target)
            gradParameters:zero()
            model:backward(input, dLossdOutput)

            accLoss = accLoss + loss
            nbBatches = nbBatches + 1

            return loss, gradParameters
         end

         optim.sgd(opfunc, parameters, sgdState)

      end

      trainTime = sys.clock() - startTime
      averageTrainLoss = accLoss / nbBatches

      ----------------------------------------------------------------------
      -- Validation losses

      do
         model:evaluate()

         local accLoss = 0.0
         local nbBatches = 0
         local startTime = sys.clock()

         for b = 1, validationData.nbSamples, batchSize do
            local nb = math.min(batchSize, validationData.nbSamples - b + 1)
            fillBatch(validationData, b, nb, batch)
            local output = model:forward(batch.input:narrow(1, 1, nb))
            accLoss = accLoss + l2criterion:forward(output, batch.target:narrow(1, 1, nb))
            nbBatches = nbBatches + 1
         end

         validationTime = sys.clock() - startTime
         averageValidationLoss = accLoss / nbBatches
      end

      logString(string.format('Epoch train %0.2fs (%0.2fms / sample), validation %0.2fs (%0.2fms / sample).\n',
                              trainTime,
                              1000 * trainTime / trainData.nbSamples,
                              validationTime,
                              1000 * validationTime / validationData.nbSamples))

      logString(string.format('LOSS %d %f %f\n', e, averageTrainLoss, averageValidationLoss),
                colors.green)

      ----------------------------------------------------------------------
      -- Save a persistent state so that we can restart from there

      if learningStateFile then
         model.RNGState = torch.getRNGState()
         model.epoch = e
         model:clearState()
         logString('Writing ' .. learningStateFile .. '.\n')
         torch.save(learningStateFile, model)
      end

      ----------------------------------------------------------------------
      -- Save a duplicate of the persistent state from time to time

      if opt.resultFreq > 0 and e%opt.resultFreq == 0 then
         torch.save(string.format('%s/epoch_%05d_model',
                                  opt.resultDir,
                                  e),
                    model)
         saveResultImage(model, trainData)
         saveResultImage(model, validationData)
      end

   end

end

-- Loads the model from the learning state file if this is possible,
-- creates a new one according to config otherwise, then trains it
-- with trainModel and returns it. The learning state file is
-- opt.learningStateFile, or 'learning.state' in opt.resultDir if the
-- option is empty.
--
-- If opt.exampleInternals is positive, the function only saves the
-- image of the internals for that validation sample and exits.
function createAndTrainModel(trainData, validationData)

   local model

   local learningStateFile = opt.learningStateFile

   if learningStateFile == '' then
      learningStateFile = opt.resultDir .. '/learning.state'
   end

   local gotlearningStateFile

   logString('Using the learning state file ' .. learningStateFile .. '\n')

   local loaded, result = pcall(torch.load, learningStateFile)

   if loaded then

      model = result
      gotlearningStateFile = true

   else

      -- Starting from scratch is expected when there is no file yet,
      -- but it deserves a warning when there is one we cannot read,
      -- since it will be overwritten
      if path.exists(learningStateFile) then
         logString('Cannot load ' .. learningStateFile
                      .. ' (' .. tostring(result) .. '), creating a new model.\n',
                   colors.red)
      end

      model = createModel(config.filterSize, config.nbChannels, config.nbBlocks)

      if useGPU then
         logString('Moving the model to the GPU.\n')
         model:cuda()
      end

   end

   logString(tostring(model) .. '\n')

   if gotlearningStateFile then
      logString(string.format('Found a learning state with %d epochs finished.\n', model.epoch or 0),
                colors.red)
   end

   if opt.exampleInternals > 0 then
      saveInternalsImage(model, validationData, opt.exampleInternals)
      os.exit(0)
   end

   trainModel(model,
              trainData, validationData,
              config.nbEpochs, config.learningRate,
              learningStateFile)

   return model

end

----------------------------------------------------------------------
-- Main: log the configuration, load the three data sets, train the
-- model and save the result images.

for i, j in pairs(config) do
   logString('config ' .. i .. ' = \'' .. tostring(j) ..'\'\n')
end

local trainData = loadData(1,
                           config.nbTrainSamples, 'train')

local validationData = loadData(config.nbTrainSamples + 1,
                                config.nbValidationSamples, 'validation')

local testData = loadData(config.nbTrainSamples + config.nbValidationSamples + 1,
                          config.nbTestSamples, 'test')

local model = createAndTrainModel(trainData, validationData)

saveResultImage(model, trainData)
saveResultImage(model, validationData)
saveResultImage(model, testData, nil, testData.nbSamples)