X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=dyncnn.git;a=blobdiff_plain;f=dyncnn.lua;h=0bb780c717182ca04ebc6a389600baa49fe23720;hp=839431ab40ccdd33d277c2927295ef40016c0ef2;hb=HEAD;hpb=4cab0b04a02e270f4e6bce11a763ddfe0a2ad2ae diff --git a/dyncnn.lua b/dyncnn.lua index 839431a..0bb780c 100755 --- a/dyncnn.lua +++ b/dyncnn.lua @@ -28,343 +28,115 @@ require 'torch' require 'nn' require 'optim' require 'image' -require 'pl' ----------------------------------------------------------------------- - -local opt = lapp[[ - --seed (default 1) random seed - - --learningStateFile (default '') - --dataDir (default './data/10p-mg/') - --resultDir (default '/tmp/dyncnn') - - --learningRate (default -1) - --momentum (default -1) - --nbEpochs (default -1) nb of epochs for the heavy setting - - --heavy use the heavy configuration - --nbChannels (default -1) nb of channels in the internal layers - --resultFreq (default 100) - - --noLog supress logging - - --exampleInternals (default -1) -]] +require 'fftb' ---------------------------------------------------------------------- +-- Command line arguments -commandLine='' -for i = 0, #arg do - commandLine = commandLine .. ' \'' .. arg[i] .. '\'' -end +local cmd = torch.CmdLine() ----------------------------------------------------------------------- +cmd:text('General setup') -colors = sys.COLORS +cmd:option('-seed', 1, 'initial random seed') +cmd:option('-nbThreads', defaultNbThreads, 'how many threads (environment variable TORCH_NB_THREADS)') +cmd:option('-useGPU', defaultUseGPU, 'should we use cuda (environment variable TORCH_USE_GPU)') +cmd:option('-fastGPU', true, 'should we go as fast as possible, possibly non-deterministically') -global = {} +cmd:text('') +cmd:text('Log') -function logString(s, c) - if global.logFile then - global.logFile:write(s) - global.logFile:flush() - end - local c = c or colors.black - io.write(c .. s) - io.flush() -end +cmd:option('-resultFreq', 100, 'at which epoch frequency should we save result images') +cmd:option('-exampleInternals', '', 'list of comma-separated indices for inner activation images') +cmd:option('-noLog', false, 'should we prevent logging') +cmd:option('-rundir', '', 'the directory for results') +cmd:option('-deltaImages', false, 'should we highlight the difference in result images') -function logCommand(c) - logString('[' .. c .. '] -> [' .. sys.execute(c) .. ']\n', colors.blue) -end +cmd:text('') +cmd:text('Network structure') -logString('commandline: ' .. commandLine .. '\n', colors.blue) +cmd:option('-filterSize', 5) +cmd:option('-nbChannels', 16) +cmd:option('-nbBlocks', 8) -logCommand('mkdir -v -p ' .. opt.resultDir) +cmd:text('') +cmd:text('Training') -if not opt.noLog then - global.logName = opt.resultDir .. '/log' - global.logFile = io.open(global.logName, 'a') -end +cmd:option('-nbEpochs', 1000, 'nb of epochs for the heavy setting') +cmd:option('-learningRate', 0.1, 'learning rate') +cmd:option('-batchSize', 128, 'size of the mini-batches') +cmd:option('-nbTrainSamples', 32768) +cmd:option('-nbValidationSamples', 1024) +cmd:option('-nbTestSamples', 1024) ----------------------------------------------------------------------- +cmd:text('') +cmd:text('Problem to solve') -alreadyLoggedString = {} +cmd:option('-dataDir', './data/10p-mg', 'data directory') -function logOnce(s) - local l = debug.getinfo(1).currentline - if not alreadyLoggedString[l] then - logString('@line ' .. l .. ' ' .. s, colors.red) - alreadyLoggedString[l] = s - end -end +cmd:addTime('DYNCNN','%F %T') ----------------------------------------------------------------------- +params = cmd:parse(arg) -nbThreads = os.getenv('TORCH_NB_THREADS') or 1 +---------------------------------------------------------------------- -useGPU = os.getenv('TORCH_USE_GPU') == 'yes' +fftbInit(cmd, params) -for _, c in pairs({ 'date', - 'uname -a', - 'git log -1 --format=%H' +for _, c in pairs({ + 'date', + 'uname -a', + 'git log -1 --format=%H' }) do logCommand(c) end -logString('useGPU is \'' .. tostring(useGPU) .. '\'.\n') - -logString('nbThreads is \'' .. nbThreads .. '\'.\n') - ----------------------------------------------------------------------- - -torch.setnumthreads(nbThreads) -torch.setdefaulttensortype('torch.FloatTensor') -torch.manualSeed(opt.seed) - -mynn = {} - --- By default, mynn returns the entries from nn -local mt = {} -function mt.__index(table, key) - return nn[key] -end -setmetatable(mynn, mt) - --- These are the tensors that can be kept on the CPU -mynn.SlowTensor = torch.Tensor --- These are the tensors that should be moved to the GPU -mynn.FastTensor = torch.Tensor - ----------------------------------------------------------------------- - -if useGPU then - require 'cutorch' - require 'cunn' - require 'cudnn' - mynn.FastTensor = torch.CudaTensor - mynn.SpatialConvolution = cudnn.SpatialConvolution -end - ----------------------------------------------------------------------- - -config = {} -config.learningRate = 0.1 -config.momentum = 0 -config.batchSize = 128 -config.filterSize = 5 - -if opt.heavy then - - logString('Using the heavy configuration.\n') - config.nbChannels = 16 - config.nbBlocks = 4 - config.nbEpochs = 250 - config.nbEpochsInit = 100 - config.nbTrainSamples = 32768 - config.nbValidationSamples = 1024 - config.nbTestSamples = 1024 - -else - - logString('Using the light configuration.\n') - config.nbChannels = 2 - config.nbBlocks = 2 - config.nbEpochs = 6 - config.nbEpochsInit = 3 - config.nbTrainSamples = 1024 - config.nbValidationSamples = 1024 - config.nbTestSamples = 1024 - -end - -if opt.nbEpochs > 0 then - config.nbEpochs = opt.nbEpochs -end - -if opt.nbChannels > 0 then - config.nbChannels = opt.nbChannels -end - -if opt.learningRate > 0 then - config.learningRate = opt.learningRate -end - -if opt.momentum >= 0 then - config.momentum = opt.momentum -end - ---------------------------------------------------------------------- -function tensorCensus(tensorType, model) - - local nb = {} - - local function countThings(m) - for k, i in pairs(m) do - if torch.type(i) == tensorType then - nb[k] = (nb[k] or 0) + i:nElement() - end - end - end +function loadData(first, nb, name) + print('Loading data `' .. name .. '\'.') - model:apply(countThings) + local data = {} - return nb + data.name = name + data.nbSamples = nb + data.width = 64 + data.height = 64 -end + data.input = ffnn.SlowTensor(data.nbSamples, 2, data.height, data.width) + data.target = ffnn.SlowTensor(data.nbSamples, 1, data.height, data.width) ----------------------------------------------------------------------- + for i = 1, data.nbSamples do + local n = i-1 + first-1 + local frame = image.load(string.format('%s/%03d/dyn_%06d.png', + params.dataDir, + math.floor(n/1000), n)) -function loadData(first, nb, name) - logString('Loading data `' .. name .. '\'.\n') - - local persistentFileName = string.format('%s/persistent_%d_%d.dat', - opt.dataDir, - first, - nb) - - -- This is at what framerate we work. It is greater than 1 so that - -- we can keep on disk sequences at a higher frame rate for videos - -- and explaining materials - - local frameRate = 4 - - local data - - if not path.exists(persistentFileName) then - logString(string.format('No persistent data structure, creating it (%d samples).\n', nb)) - local data = {} - data.name = name - data.nbSamples = nb - data.width = 64 - data.height = 64 - data.input = mynn.SlowTensor(data.nbSamples, 2, data.height, data.width) - data.target = mynn.SlowTensor(data.nbSamples, 1, data.height, data.width) - - for i = 1, data.nbSamples do - local n = i-1 + first-1 - local prefix = string.format('%s/%03d/dyn_%06d', - opt.dataDir, - math.floor(n/1000), n) - - function localLoad(filename, tensor) - local tmp - tmp = image.load(filename) - tmp:mul(-1.0):add(1.0) - tensor:copy(torch.max(tmp, 1)) - end + frame:mul(-1.0):add(1.0) + frame = frame:max(1):select(1, 1) - localLoad(prefix .. '_world_000.png', data.input[i][1]) - localLoad(prefix .. '_grab.png', data.input[i][2]) - localLoad(string.format('%s_world_%03d.png', prefix, frameRate), - data.target[i][1]) - end + data.input[i][1]:copy(frame:sub(0 * data.height + 1, 1 * data.height, + 1 * data.width + 1, 2 * data.width)) - data.persistentFileName = persistentFileName + data.input[i][2]:copy(frame:sub(0 * data.height + 1, 1 * data.height, + 0 * data.width + 1, 1 * data.width)) - torch.save(persistentFileName, data) + data.target[i][1]:copy(frame:sub(1 * data.height + 1, 2 * data.height, + 1 * data.width + 1, 2 * data.width)) end - logCommand('sha256sum -b ' .. persistentFileName) - - data = torch.load(persistentFileName) - return data end ---------------------------------------------------------------------- --- This function gets as input a list of tensors of arbitrary --- dimensions each, but whose two last dimension stands for height x --- width. It creates an image tensor (2d, one channel) with each --- argument tensor unfolded per row. - -function imageFromTensors(bt, signed) - local gap = 1 - local tgap = -1 - local width = 0 - local height = gap - - for _, t in pairs(bt) do - -- print(t:size()) - local d = t:dim() - local h, w = t:size(d - 1), t:size(d) - local n = t:nElement() / (w * h) - width = math.max(width, gap + n * (gap + w)) - height = height + gap + tgap + gap + h - end - - local e = torch.Tensor(3, height, width):fill(1.0) - local y0 = 1 + gap - - for _, t in pairs(bt) do - local d = t:dim() - local h, w = t:size(d - 1), t:size(d) - local n = t:nElement() / (w * h) - local z = t:norm() / math.sqrt(t:nElement()) - - local x0 = 1 + gap + math.floor( (width - n * (w + gap)) /2 ) - local u = torch.Tensor(t:size()):copy(t):resize(n, h, w) - for m = 1, n do - - for c = 1, 3 do - for y = 0, h+1 do - e[c][y0 + y - 1][x0 - 1] = 0.0 - e[c][y0 + y - 1][x0 + w ] = 0.0 - end - for x = 0, w+1 do - e[c][y0 - 1][x0 + x - 1] = 0.0 - e[c][y0 + h ][x0 + x - 1] = 0.0 - end - end - - for y = 1, h do - for x = 1, w do - local v = u[m][y][x] / z - local r, g, b - if signed then - if v < -1 then - r, g, b = 0.0, 0.0, 1.0 - elseif v > 1 then - r, g, b = 1.0, 0.0, 0.0 - elseif v >= 0 then - r, g, b = 1.0, 1.0 - v, 1.0 - v - else - r, g, b = 1.0 + v, 1.0 + v, 1.0 - end - else - if v <= 0 then - r, g, b = 1.0, 1.0, 1.0 - elseif v > 1 then - r, g, b = 0.0, 0.0, 0.0 - else - r, g, b = 1.0 - v, 1.0 - v, 1.0 - v - end - end - e[1][y0 + y - 1][x0 + x - 1] = r - e[2][y0 + y - 1][x0 + x - 1] = g - e[3][y0 + y - 1][x0 + x - 1] = b - end - end - x0 = x0 + w + gap - end - y0 = y0 + h + gap + tgap + gap - end - - return e -end - function collectAllOutputs(model, collection, which) if torch.type(model) == 'nn.Sequential' then for i = 1, #model.modules do collectAllOutputs(model.modules[i], collection, which) end elseif not which or which[torch.type(model)] then - local t = torch.type(model.output) - if t == 'torch.FloatTensor' or t == 'torch.CudaTensor' then + if torch.isTensor(model.output) then collection.nb = collection.nb + 1 collection.outputs[collection.nb] = model.output end @@ -372,8 +144,8 @@ function collectAllOutputs(model, collection, which) end function saveInternalsImage(model, data, n) - -- Explicitely copy to keep input as a mynn.FastTensor - local input = mynn.FastTensor(1, 2, data.height, data.width) + -- Explicitely copy to keep input as a ffnn.FastTensor + local input = ffnn.FastTensor(1, 2, data.height, data.width) input:copy(data.input:narrow(1, n, 1)) local output = model:forward(input) @@ -383,9 +155,13 @@ function saveInternalsImage(model, data, n) collection.nb = 1 collection.outputs[collection.nb] = input - local which = {} - which['nn.ReLU'] = true - collectAllOutputs(model, collection, which) + collectAllOutputs(model, collection, + { + ['nn.ReLU'] = true, + ['cunn.ReLU'] = true, + ['cudnn.ReLU'] = true, + } + ) if collection.outputs[collection.nb] ~= model.output then collection.nb = collection.nb + 1 @@ -393,27 +169,35 @@ function saveInternalsImage(model, data, n) end local fileName = string.format('%s/internals_%s_%06d.png', - opt.resultDir, + params.rundir, data.name, n) - logString('Saving ' .. fileName .. '\n') + print('Saving ' .. fileName) image.save(fileName, imageFromTensors(collection.outputs)) end ---------------------------------------------------------------------- -function saveResultImage(model, data, prefix, nbMax, highlight) - local l2criterion = nn.MSECriterion() +function highlightImage(a, b) + if params.deltaImages then + local h = torch.csub(a, b):abs() + h:div(1/h:max()):mul(0.9):add(0.1) + return torch.cmul(a, h) + else + return a + end +end + +function saveResultImage(model, data, nbMax) + local criterion = nn.MSECriterion() - if useGPU then - logString('Moving the criterion to the GPU.\n') - l2criterion:cuda() + if params.useGPU then + print('Moving the criterion to the GPU.') + criterion:cuda() end - local prefix = prefix or 'result' - local result = torch.Tensor(data.height * 4 + 5, data.width + 2) - local input = mynn.FastTensor(1, 2, data.height, data.width) - local target = mynn.FastTensor(1, 1, data.height, data.width) + local input = ffnn.FastTensor(1, 2, data.height, data.width) + local target = ffnn.FastTensor(1, 1, data.height, data.width) local nbMax = nbMax or 50 @@ -421,114 +205,111 @@ function saveResultImage(model, data, prefix, nbMax, highlight) model:evaluate() - logString(string.format('Write %d result images `%s\' for set `%s\' in %s.\n', - nb, prefix, data.name, - opt.resultDir)) + printf('Write %d result images for `%s\'.', nb, data.name) + + local lossFile = io.open(params.rundir .. '/result_' .. data.name .. '_losses.dat', 'w') for n = 1, nb do - -- Explicitely copy to keep input as a mynn.FastTensor + -- Explicitely copy to keep input as a ffnn.FastTensor input:copy(data.input:narrow(1, n, 1)) target:copy(data.target:narrow(1, n, 1)) local output = model:forward(input) + local loss = criterion:forward(output, target) - local loss = l2criterion:forward(output, target) - - result:fill(1.0) - - if highlight then - for i = 1, data.height do - for j = 1, data.width do - local v = data.input[n][1][i][j] - result[1 + i + 0 * (data.height + 1)][1 + j] = data.input[n][2][i][j] - result[1 + i + 1 * (data.height + 1)][1 + j] = v - local a = data.target[n][1][i][j] - local b = output[1][1][i][j] - result[1 + i + 2 * (data.height + 1)][1 + j] = - a * math.min(1, 0.1 + 2.0 * math.abs(a - v)) - result[1 + i + 3 * (data.height + 1)][1 + j] = - b * math.min(1, 0.1 + 2.0 * math.abs(b - v)) - end - end - else - for i = 1, data.height do - for j = 1, data.width do - result[1 + i + 0 * (data.height + 1)][1 + j] = data.input[n][2][i][j] - result[1 + i + 1 * (data.height + 1)][1 + j] = data.input[n][1][i][j] - result[1 + i + 2 * (data.height + 1)][1 + j] = data.target[n][1][i][j] - result[1 + i + 3 * (data.height + 1)][1 + j] = output[1][1][i][j] - end - end - end + output = ffnn.SlowTensor(output:size()):copy(output) - result:mul(-1.0):add(1.0) + -- We use our magical img.lua to create the result images + + local comp - local fileName = string.format('%s/%s_%s_%06d.png', - opt.resultDir, - prefix, - data.name, n) + comp = { + { + vertical = true, + { pad = 1, data.input[n][1] }, + { pad = 1, data.input[n][2] }, + { pad = 1, highlightImage(data.target[n][1], data.input[n][1]) }, + { pad = 1, highlightImage(output[1][1], data.input[n][1]) }, + } + } - logString(string.format('LOSS_ON_SAMPLE %f %s\n', loss, fileName)) + local result = combineImages(1.0, comp) - image.save(fileName, result) + result:mul(-1.0):add(1.0) + + local fileName = string.format('result_%s_%06d.png', data.name, n) + image.save(params.rundir .. '/' .. fileName, result) + lossFile:write(string.format('%f %s\n', loss, fileName)) end end ---------------------------------------------------------------------- function createTower(filterSize, nbChannels, nbBlocks) - local tower = mynn.Sequential() - for b = 1, nbBlocks do - local block = mynn.Sequential() + local tower + + if nbBlocks == 0 then + + tower = nn.Identity() + + else + + tower = ffnn.Sequential() + + for b = 1, nbBlocks do + local block = ffnn.Sequential() + + block:add(ffnn.SpatialConvolution(nbChannels, + nbChannels, + filterSize, filterSize, + 1, 1, + (filterSize - 1) / 2, (filterSize - 1) / 2)) + block:add(ffnn.SpatialBatchNormalization(nbChannels)) + block:add(ffnn.ReLU(true)) - block:add(mynn.SpatialConvolution(nbChannels, - nbChannels, - filterSize, filterSize, - 1, 1, - (filterSize - 1) / 2, (filterSize - 1) / 2)) - block:add(mynn.SpatialBatchNormalization(nbChannels)) - block:add(mynn.ReLU(true)) + block:add(ffnn.SpatialConvolution(nbChannels, + nbChannels, + filterSize, filterSize, + 1, 1, + (filterSize - 1) / 2, (filterSize - 1) / 2)) - block:add(mynn.SpatialConvolution(nbChannels, - nbChannels, - filterSize, filterSize, - 1, 1, - (filterSize - 1) / 2, (filterSize - 1) / 2)) + local parallel = ffnn.ConcatTable() + parallel:add(block):add(ffnn.Identity()) - local parallel = mynn.ConcatTable() - parallel:add(block):add(mynn.Identity()) + tower:add(parallel):add(ffnn.CAddTable(true)) - tower:add(parallel):add(mynn.CAddTable(true)) + tower:add(ffnn.SpatialBatchNormalization(nbChannels)) + tower:add(ffnn.ReLU(true)) + end - tower:add(mynn.SpatialBatchNormalization(nbChannels)) - tower:add(mynn.ReLU(true)) end return tower end -function createModel(filterSize, nbChannels, nbBlocks) - local model = mynn.Sequential() +function createModel(imageWidth, imageHeight, + filterSize, nbChannels, nbBlocks) + + local model = ffnn.Sequential() - model:add(mynn.SpatialConvolution(2, + -- Encode the two input channels (grasping image and starting + -- configuration) into the internal number of channels + model:add(ffnn.SpatialConvolution(2, nbChannels, filterSize, filterSize, 1, 1, (filterSize - 1) / 2, (filterSize - 1) / 2)) - model:add(mynn.SpatialBatchNormalization(nbChannels)) - model:add(mynn.ReLU(true)) - - local towerCode = createTower(filterSize, nbChannels, nbBlocks) - local towerDecode = createTower(filterSize, nbChannels, nbBlocks) + model:add(ffnn.SpatialBatchNormalization(nbChannels)) + model:add(ffnn.ReLU(true)) - model:add(towerCode) - model:add(towerDecode) + -- Add the resnet modules + model:add(createTower(filterSize, nbChannels, nbBlocks)) - -- Decode to a single channel, which is the final image - model:add(mynn.SpatialConvolution(nbChannels, + -- Decode down to a single channel, which is the final image + model:add(ffnn.SpatialConvolution(nbChannels, 1, filterSize, filterSize, 1, 1, @@ -539,34 +320,10 @@ end ---------------------------------------------------------------------- -function fillBatch(data, first, nb, batch, permutation) - for k = 1, nb do - local i - if permutation then - i = permutation[first + k - 1] - else - i = first + k - 1 - end - batch.input[k] = data.input[i] - batch.target[k] = data.target[i] - end -end - -function trainModel(model, - trainData, validationData, nbEpochs, learningRate, - learningStateFile) - - local l2criterion = nn.MSECriterion() - local batchSize = config.batchSize - - if useGPU then - logString('Moving the criterion to the GPU.\n') - l2criterion:cuda() - end +function trainModel(model, trainSet, validationSet) - local batch = {} - batch.input = mynn.FastTensor(batchSize, 2, trainData.height, trainData.width) - batch.target = mynn.FastTensor(batchSize, 1, trainData.height, trainData.width) + local criterion = nn.MSECriterion() + local batchSize = params.batchSize local startingEpoch = 1 @@ -575,47 +332,59 @@ function trainModel(model, end if model.RNGState then + printfc(colors.red, 'Using the RNG state from the loaded model.') torch.setRNGState(model.RNGState) end - logString('Starting training.\n') + if params.useGPU then + print('Moving the model and criterion to the GPU.') + model:cuda() + criterion:cuda() + end + + print('Starting training.') local parameters, gradParameters = model:getParameters() - logString(string.format('model has %d parameters.\n', parameters:storage():size(1))) + printf('The model has %d parameters.', parameters:storage():size(1)) local averageTrainLoss, averageValidationLoss local trainTime, validationTime + ---------------------------------------------------------------------- + local sgdState = { - learningRate = config.learningRate, - momentum = config.momentum, + learningRate = params.learningRate, + momentum = 0, learningRateDecay = 0 } - for e = startingEpoch, nbEpochs do + local batch = {} + + for e = startingEpoch, params.nbEpochs do model:training() - local permutation = torch.randperm(trainData.nbSamples) + local permutation = torch.randperm(trainSet.nbSamples) local accLoss = 0.0 local nbBatches = 0 local startTime = sys.clock() - for b = 1, trainData.nbSamples, batchSize do + for b = 1, trainSet.nbSamples, batchSize do - fillBatch(trainData, b, batchSize, batch, permutation) + fillBatch(trainSet, b, batch, permutation) local opfunc = function(x) - -- Surprisingly copy() needs this check + -- Surprisingly, copy() needs this check if x ~= parameters then parameters:copy(x) end local output = model:forward(batch.input) - local loss = l2criterion:forward(output, batch.target) - local dLossdOutput = l2criterion:backward(output, batch.target) + local loss = criterion:forward(output, batch.target) + local dLossdOutput = criterion:backward(output, batch.target) + gradParameters:zero() model:backward(batch.input, dLossdOutput) @@ -634,6 +403,7 @@ function trainModel(model, ---------------------------------------------------------------------- -- Validation losses + do model:evaluate() @@ -641,10 +411,10 @@ function trainModel(model, local nbBatches = 0 local startTime = sys.clock() - for b = 1, validationData.nbSamples, batchSize do - fillBatch(validationData, b, batchSize, batch) + for b = 1, validationSet.nbSamples, batchSize do + fillBatch(validationSet, b, batch) local output = model:forward(batch.input) - accLoss = accLoss + l2criterion:forward(output, batch.target) + accLoss = accLoss + criterion:forward(output, batch.target) nbBatches = nbBatches + 1 end @@ -652,99 +422,91 @@ function trainModel(model, averageValidationLoss = accLoss / nbBatches; end - logString(string.format('Epoch train %0.2fs (%0.2fms / sample), validation %0.2fs (%0.2fms / sample).\n', - trainTime, - 1000 * trainTime / trainData.nbSamples, - validationTime, - 1000 * validationTime / validationData.nbSamples)) + ---------------------------------------------------------------------- + + printfc(colors.green, + + 'epoch %d acc_train_loss %f validation_loss %f [train %.02fs total %.02fms / sample, validation %.02fs total %.02fms / sample]', + + e, + + averageTrainLoss, - logString(string.format('LOSS %d %f %f\n', e, averageTrainLoss, averageValidationLoss), - colors.green) + averageValidationLoss, + + trainTime, + 1000 * trainTime / trainSet.nbSamples, + + validationTime, + 1000 * validationTime / validationSet.nbSamples + ) ---------------------------------------------------------------------- -- Save a persistent state so that we can restart from there - if learningStateFile then - model.RNGState = torch.getRNGState() - model.epoch = e - model:clearState() - logString('Writing ' .. learningStateFile .. '.\n') - torch.save(learningStateFile, model) - end + model:clearState() + model.RNGState = torch.getRNGState() + model.epoch = e + torch.save(params.rundir .. '/model_last.t7', model) ---------------------------------------------------------------------- -- Save a duplicate of the persistent state from time to time - if opt.resultFreq > 0 and e%opt.resultFreq == 0 then - torch.save(string.format('%s/epoch_%05d_model', opt.resultDir, e), model) - saveResultImage(model, trainData) - saveResultImage(model, validationData) + if params.resultFreq > 0 and e%params.resultFreq == 0 then + torch.save(string.format('%s/model_%04d.t7', params.rundir, e), model) + saveResultImage(model, trainSet) + saveResultImage(model, validationSet) end end end -function createAndTrainModel(trainData, validationData) - - local model - - local learningStateFile = opt.learningStateFile - - if learningStateFile == '' then - learningStateFile = opt.resultDir .. '/learning.state' - end +---------------------------------------------------------------------- +-- main - local gotlearningStateFile +local trainSet = loadData(1, + params.nbTrainSamples, 'train') - logString('Using the learning state file ' .. learningStateFile .. '\n') +local validationSet = loadData(params.nbTrainSamples + 1, + params.nbValidationSamples, 'validation') - if pcall(function () model = torch.load(learningStateFile) end) then +local model - gotlearningStateFile = true +if pcall(function () model = torch.load(params.rundir .. '/model_last.t7') end) then - else + printfc(colors.red, + 'Found a model with %d epochs completed, starting from there.', + model.epoch) - model = createModel(config.filterSize, config.nbChannels, config.nbBlocks) - - if useGPU then - logString('Moving the model to the GPU.\n') - model:cuda() + if params.exampleInternals ~= '' then + for _, i in ipairs(string.split(params.exampleInternals, ',')) do + saveInternalsImage(model, validationSet, tonumber(i)) end - - end - - logString(tostring(model) .. '\n') - - if gotlearningStateFile then - logString(string.format('Found a learning state with %d epochs finished.\n', model.epoch), - colors.red) - end - - if opt.exampleInternals > 0 then - saveInternalsImage(model, validationData, opt.exampleInternals) os.exit(0) end - trainModel(model, - trainData, validationData, - config.nbEpochs, config.learningRate, - learningStateFile) +else - return model + model = createModel(trainSet.width, trainSet.height, + params.filterSize, params.nbChannels, + params.nbBlocks) end -for i, j in pairs(config) do - logString('config ' .. i .. ' = \'' .. j ..'\'\n') -end +trainModel(model, trainSet, validationSet) + +---------------------------------------------------------------------- +-- Test -local trainData = loadData(1, config.nbTrainSamples, 'train') -local validationData = loadData(config.nbTrainSamples + 1, config.nbValidationSamples, 'validation') -local testData = loadData(config.nbTrainSamples + config.nbValidationSamples + 1, config.nbTestSamples, 'test') +local testSet = loadData(params.nbTrainSamples + params.nbValidationSamples + 1, + params.nbTestSamples, 'test') -local model = createAndTrainModel(trainData, validationData) +if params.useGPU then + print('Moving the model and criterion to the GPU.') + model:cuda() +end -saveResultImage(model, trainData) -saveResultImage(model, validationData) -saveResultImage(model, testData, nil, testData.nbSamples) +saveResultImage(model, trainSet) +saveResultImage(model, validationSet) +saveResultImage(model, testSet, 1024)