Many fixes, now generates a single image per frame.
[dyncnn.git] / dyncnn.lua
1 #!/usr/bin/env luajit
2
3 --[[
4
5    dyncnn is a deep-learning algorithm for the prediction of
6    interacting object dynamics
7
8    Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/
9    Written by Francois Fleuret <francois.fleuret@idiap.ch>
10
11    This file is part of dyncnn.
12
13    dyncnn is free software: you can redistribute it and/or modify it
14    under the terms of the GNU General Public License version 3 as
15    published by the Free Software Foundation.
16
17    dyncnn is distributed in the hope that it will be useful, but
18    WITHOUT ANY WARRANTY; without even the implied warranty of
19    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20    General Public License for more details.
21
22    You should have received a copy of the GNU General Public License
23    along with dyncnn.  If not, see <http://www.gnu.org/licenses/>.
24
25 ]]--
26
27 require 'torch'
28 require 'nn'
29 require 'optim'
30 require 'image'
31 require 'pl'
32
33 require 'img'
34
35 ----------------------------------------------------------------------
36
37 function printf(f, ...)
38    print(string.format(f, unpack({...})))
39 end
40
41 colors = sys.COLORS
42
43 function printfc(c, f, ...)
44    printf(c .. string.format(f, unpack({...})) .. colors.black)
45 end
46
47 function logCommand(c)
48    print(colors.blue .. '[' .. c .. '] -> [' .. sys.execute(c) .. ']' .. colors.black)
49 end
50
51 ----------------------------------------------------------------------
52 -- Environment and command line arguments
53
54 local defaultNbThreads = 1
55 local defaultUseGPU = false
56
57 if os.getenv('TORCH_NB_THREADS') then
58    defaultNbThreads = os.getenv('TORCH_NB_THREADS')
59    print('Environment variable TORCH_NB_THREADS is set and equal to ' .. defaultNbThreads)
60 else
61    print('Environment variable TORCH_NB_THREADS is not set')
62 end
63
64 if os.getenv('TORCH_USE_GPU') then
65    defaultUseGPU = os.getenv('TORCH_USE_GPU') == 'yes'
66    print('Environment variable TORCH_USE_GPU is set and evaluated as ' .. tostring(defaultUseGPU))
67 else
68    print('Environment variable TORCH_USE_GPU is not set.')
69 end
70
71 ----------------------------------------------------------------------
72
73 local cmd = torch.CmdLine()
74
75 cmd:text('')
76 cmd:text('General setup')
77
78 cmd:option('-seed', 1, 'initial random seed')
79 cmd:option('-nbThreads', defaultNbThreads, 'how many threads (environment variable TORCH_NB_THREADS)')
80 cmd:option('-useGPU', defaultUseGPU, 'should we use cuda (environment variable TORCH_USE_GPU)')
81
82 cmd:text('')
83 cmd:text('Log')
84
85 cmd:option('-resultFreq', 100, 'at which epoch frequency should we save result images')
86 cmd:option('-exampleInternals', -1, 'should we save inner activation images')
87 cmd:option('-noLog', false, 'should we prevent logging')
88 cmd:option('-rundir', '', 'the directory for results')
89
90 cmd:text('')
91 cmd:text('Training')
92
93 cmd:option('-nbEpochs', 1000, 'nb of epochs for the heavy setting')
94 cmd:option('-learningRate', 0.1, 'learning rate')
95 cmd:option('-batchSize', 128, 'size of the mini-batches')
96 cmd:option('-filterSize', 5, 'convolution filter size')
97 cmd:option('-nbTrainSamples', 32768)
98 cmd:option('-nbValidationSamples', 1024)
99 cmd:option('-nbTestSamples', 1024)
100
101 cmd:text('')
102 cmd:text('Problem to solve')
103
104 cmd:option('-dataDir', './data/10p-mg', 'data directory')
105
106 cmd:text('')
107 cmd:text('Network structure')
108
109 cmd:option('-nbChannels', 16)
110 cmd:option('-nbBlocks', 8)
111
112 ------------------------------
113 -- Log and stuff
114
115 cmd:addTime('DYNCNN','%F %T')
116
117 params = cmd:parse(arg)
118
119 if params.rundir == '' then
120    params.rundir = cmd:string('exp', params, { })
121 end
122
123 paths.mkdir(params.rundir)
124
125 if not params.noLog then
126    -- Append to the log if there is one
127    cmd:log(io.open(params.rundir .. '/log', 'a'), params)
128 end
129
130 ----------------------------------------------------------------------
131 -- The experiment per se
132
133 if params.predictGrasp then
134    params.targetDepth = 2
135 else
136    params.targetDepth = 1
137 end
138
139 ----------------------------------------------------------------------
140 -- Initializations
141
142 torch.setnumthreads(params.nbThreads)
143 torch.setdefaulttensortype('torch.FloatTensor')
144 torch.manualSeed(params.seed)
145
146 ----------------------------------------------------------------------
147 -- Dealing with the CPU/GPU
148
149 -- mynn will take entries in that order: mynn, cudnn, cunn, nn
150
151 mynn = {}
152
153 setmetatable(mynn,
154              {
155                 __index = function(table, key)
156                    return (cudnn and cudnn[key]) or (cunn and cunn[key]) or nn[key]
157                 end
158              }
159 )
160
161 -- These are the tensors that can be kept on the CPU
162 mynn.SlowTensor = torch.Tensor
163
164 -- These are the tensors that should be moved to the GPU
165 mynn.FastTensor = torch.Tensor
166
167 if params.useGPU then
168    require 'cutorch'
169    require 'cunn'
170    require 'cudnn'
171    cudnn.benchmark = true
172    cudnn.fastest = true
173    mynn.FastTensor = torch.CudaTensor
174 end
175
176 ----------------------------------------------------------------------
177
178 function loadData(first, nb, name)
179    print('Loading data `' .. name .. '\'.')
180
181    local data = {}
182
183    data.name = name
184    data.nbSamples = nb
185    data.width = 64
186    data.height = 64
187
188    data.input = mynn.SlowTensor(data.nbSamples, 2, data.height, data.width)
189    data.target = mynn.SlowTensor(data.nbSamples, 1, data.height, data.width)
190
191    for i = 1, data.nbSamples do
192       local n = i-1 + first-1
193       local frame = image.load(string.format('%s/%03d/dyn_%06d.png',
194                                              params.dataDir,
195                                              math.floor(n/1000), n))
196
197       frame:mul(-1.0):add(1.0)
198       frame = frame:max(1):select(1, 1)
199
200       data.input[i][1]:copy(frame:sub(0 * data.height + 1, 1 * data.height,
201                                       1 * data.width  + 1, 2 * data.width))
202
203       data.input[i][2]:copy(frame:sub(0 * data.height + 1, 1 * data.height,
204                                       0 * data.width  + 1, 1 * data.width))
205
206       data.target[i][1]:copy(frame:sub(1 * data.height + 1, 2 * data.height,
207                                        1 * data.width  + 1, 2 * data.width))
208    end
209
210    return data
211 end
212
213 ----------------------------------------------------------------------
214
215 function collectAllOutputs(model, collection, which)
216    if torch.type(model) == 'nn.Sequential' then
217       for i = 1, #model.modules do
218          collectAllOutputs(model.modules[i], collection, which)
219       end
220    elseif not which or which[torch.type(model)] then
221       if torch.isTensor(model.output) then
222          collection.nb = collection.nb + 1
223          collection.outputs[collection.nb] = model.output
224       end
225    end
226 end
227
228 function saveInternalsImage(model, data, n)
229    -- Explicitely copy to keep input as a mynn.FastTensor
230    local input = mynn.FastTensor(1, 2, data.height, data.width)
231    input:copy(data.input:narrow(1, n, 1))
232
233    local output = model:forward(input)
234
235    local collection = {}
236    collection.outputs = {}
237    collection.nb = 1
238    collection.outputs[collection.nb] = input
239
240    collectAllOutputs(model, collection,
241                      {
242                         ['nn.ReLU'] = true,
243                         ['cunn.ReLU'] = true,
244                         ['cudnn.ReLU'] = true,
245                      }
246    )
247
248    if collection.outputs[collection.nb] ~= model.output then
249       collection.nb = collection.nb + 1
250       collection.outputs[collection.nb] = model.output
251    end
252
253    local fileName = string.format('%s/internals_%s_%06d.png',
254                                   params.rundir,
255                                   data.name, n)
256
257    print('Saving ' .. fileName)
258    image.save(fileName, imageFromTensors(collection.outputs))
259 end
260
261 ----------------------------------------------------------------------
262
263 function saveResultImage(model, data, nbMax)
264    local criterion = nn.MSECriterion()
265
266    if params.useGPU then
267       print('Moving the criterion to the GPU.')
268       criterion:cuda()
269    end
270
271    local input = mynn.FastTensor(1, 2, data.height, data.width)
272    local target = mynn.FastTensor(1, 1, data.height, data.width)
273
274    local nbMax = nbMax or 50
275
276    local nb = math.min(nbMax, data.nbSamples)
277
278    model:evaluate()
279
280    printf('Write %d result images for `%s\'.', nb, data.name)
281
282    local lossFile = io.open(params.rundir .. '/result_' .. data.name .. '_losses.dat', 'w')
283
284    for n = 1, nb do
285
286       -- Explicitely copy to keep input as a mynn.FastTensor
287       input:copy(data.input:narrow(1, n, 1))
288       target:copy(data.target:narrow(1, n, 1))
289
290       local output = model:forward(input)
291       local loss = criterion:forward(output, target)
292
293       output = mynn.SlowTensor(output:size()):copy(output)
294
295       -- We use our magical img.lua to create the result images
296
297       local comp = {
298          {
299             { pad = 1, data.input[n][1] },
300             { pad = 1, data.input[n][2] },
301             { pad = 1, data.target[n][1] },
302             { pad = 1, output[1][1] },
303          }
304       }
305
306       --[[
307       local comp = {
308          {
309             vertical = true,
310             { pad = 1, data.input[n][1] },
311             { pad = 1, data.input[n][2] }
312          },
313          torch.Tensor(4, 4):fill(1.0),
314          {
315             vertical = true,
316             { pad = 1, data.target[n][1] },
317             { pad = 1, output[1][1] },
318             { pad = 1, torch.csub(data.target[n][1], output[1][1]):abs() }
319          }
320       }
321       ]]--
322
323 local result = combineImages(1.0, comp)
324
325 result:mul(-1.0):add(1.0)
326
327 local fileName = string.format('result_%s_%06d.png', data.name, n)
328 image.save(params.rundir .. '/' .. fileName, result)
329 lossFile:write(string.format('%f %s\n', loss, fileName))
330 end
331 end
332
333 ----------------------------------------------------------------------
334
335 function createTower(filterSize, nbChannels, nbBlocks)
336
337    local tower
338
339    if nbBlocks == 0 then
340
341       tower = nn.Identity()
342
343    else
344
345       tower = mynn.Sequential()
346
347       for b = 1, nbBlocks do
348          local block = mynn.Sequential()
349
350          block:add(mynn.SpatialConvolution(nbChannels,
351                                            nbChannels,
352                                            filterSize, filterSize,
353                                            1, 1,
354                                            (filterSize - 1) / 2, (filterSize - 1) / 2))
355          block:add(mynn.SpatialBatchNormalization(nbChannels))
356          block:add(mynn.ReLU(true))
357
358          block:add(mynn.SpatialConvolution(nbChannels,
359                                            nbChannels,
360                                            filterSize, filterSize,
361                                            1, 1,
362                                            (filterSize - 1) / 2, (filterSize - 1) / 2))
363
364          local parallel = mynn.ConcatTable()
365          parallel:add(block):add(mynn.Identity())
366
367          tower:add(parallel):add(mynn.CAddTable(true))
368
369          tower:add(mynn.SpatialBatchNormalization(nbChannels))
370          tower:add(mynn.ReLU(true))
371       end
372
373    end
374
375    return tower
376
377 end
378
379 function createModel(imageWidth, imageHeight,
380                      filterSize, nbChannels, nbBlocks)
381
382    local model = mynn.Sequential()
383
384    -- Encode the two input channels (grasping image and starting
385    -- configuration) into the internal number of channels
386    model:add(mynn.SpatialConvolution(2,
387                                      nbChannels,
388                                      filterSize, filterSize,
389                                      1, 1,
390                                      (filterSize - 1) / 2, (filterSize - 1) / 2))
391
392    model:add(mynn.SpatialBatchNormalization(nbChannels))
393    model:add(mynn.ReLU(true))
394
395    -- Add the resnet modules
396    model:add(createTower(filterSize, nbChannels, nbBlocks))
397
398    -- Decode down to a single channel, which is the final image
399    model:add(mynn.SpatialConvolution(nbChannels,
400                                      1,
401                                      filterSize, filterSize,
402                                      1, 1,
403                                      (filterSize - 1) / 2, (filterSize - 1) / 2))
404
405    return model
406 end
407
408 ----------------------------------------------------------------------
409
410 function fillBatch(data, first, batch, permutation)
411    local actualBatchSize = math.min(params.batchSize, data.input:size(1) - first + 1)
412
413    if actualBatchSize ~= batch.input:size(1) then
414       local size = batch.input:size()
415       size[1] = actualBatchSize
416       batch.input:resize(size)
417    end
418
419    if actualBatchSize ~= batch.target:size(1) then
420       local size = batch.target:size()
421       size[1] = actualBatchSize
422       batch.target:resize(size)
423    end
424
425    for k = 1, batch.input:size(1) do
426       local i
427       if permutation then
428          i = permutation[first + k - 1]
429       else
430          i = first + k - 1
431       end
432       batch.input[k] = data.input[i]
433       batch.target[k] = data.target[i]
434    end
435 end
436
437 function trainModel(model, trainData, validationData)
438
439    local criterion = nn.MSECriterion()
440    local batchSize = params.batchSize
441
442    local batch = {}
443    batch.input = mynn.FastTensor(batchSize, 2, trainData.height, trainData.width)
444    batch.target = mynn.FastTensor(batchSize, 1, trainData.height, trainData.width)
445
446    local startingEpoch = 1
447
448    if model.epoch then
449       startingEpoch = model.epoch + 1
450    end
451
452    if model.RNGState then
453       torch.setRNGState(model.RNGState)
454    end
455
456    if params.useGPU then
457       print('Moving the model and criterion to the GPU.')
458       model:cuda()
459       criterion:cuda()
460    end
461
462    print('Starting training.')
463
464    local parameters, gradParameters = model:getParameters()
465    printf('The model has %d parameters.', parameters:storage():size(1))
466
467    local averageTrainLoss, averageValidationLoss
468    local trainTime, validationTime
469
470    ----------------------------------------------------------------------
471
472    local sgdState = {
473       learningRate = params.learningRate,
474       momentum = 0,
475       learningRateDecay = 0
476    }
477
478    for e = startingEpoch, params.nbEpochs do
479
480       model:training()
481
482       local permutation = torch.randperm(trainData.nbSamples)
483
484       local accLoss = 0.0
485       local nbBatches = 0
486       local startTime = sys.clock()
487
488       for b = 1, trainData.nbSamples, batchSize do
489
490          fillBatch(trainData, b, batch, permutation)
491
492          local opfunc = function(x)
493             -- Surprisingly, copy() needs this check
494             if x ~= parameters then
495                parameters:copy(x)
496             end
497
498             local output = model:forward(batch.input)
499
500             local loss = criterion:forward(output, batch.target)
501             local dLossdOutput = criterion:backward(output, batch.target)
502
503             gradParameters:zero()
504             model:backward(batch.input, dLossdOutput)
505
506             accLoss = accLoss + loss
507             nbBatches = nbBatches + 1
508
509             return loss, gradParameters
510          end
511
512          optim.sgd(opfunc, parameters, sgdState)
513
514       end
515
516       trainTime = sys.clock() - startTime
517       averageTrainLoss = accLoss / nbBatches
518
519       ----------------------------------------------------------------------
520       -- Validation losses
521
522       do
523          model:evaluate()
524
525          local accLoss = 0.0
526          local nbBatches = 0
527          local startTime = sys.clock()
528
529          for b = 1, validationData.nbSamples, batchSize do
530             fillBatch(validationData, b, batch)
531             local output = model:forward(batch.input)
532             accLoss = accLoss + criterion:forward(output, batch.target)
533             nbBatches = nbBatches + 1
534          end
535
536          validationTime = sys.clock() - startTime
537          averageValidationLoss = accLoss / nbBatches;
538       end
539
540       printf('Epoch train %0.2fs (%0.2fms / sample), validation %0.2fs (%0.2fms / sample).',
541              trainTime,
542              1000 * trainTime / trainData.nbSamples,
543              validationTime,
544              1000 * validationTime / validationData.nbSamples)
545
546       printfc(colors.green, 'LOSS %d %f %f', e, averageTrainLoss, averageValidationLoss)
547
548       ----------------------------------------------------------------------
549       -- Save a persistent state so that we can restart from there
550
551       model:clearState()
552       model.RNGState = torch.getRNGState()
553       model.epoch = e
554       torch.save(params.rundir .. '/model_last.t7', model)
555
556       ----------------------------------------------------------------------
557       -- Save a duplicate of the persistent state from time to time
558
559       if params.resultFreq > 0 and e%params.resultFreq == 0 then
560          torch.save(string.format('%s/model_%04d.t7', params.rundir, e), model)
561          saveResultImage(model, trainData)
562          saveResultImage(model, validationData)
563       end
564
565    end
566
567 end
568
569 function createAndTrainModel(trainData, validationData)
570
571    -- Load the current training state, or create a new model from
572    -- scratch
573
574    if pcall(function () model = torch.load(params.rundir .. '/model_last.t7') end) then
575
576       printfc(colors.red,
577               'Found a learning state with %d epochs finished, starting from there.',
578               model.epoch)
579
580       if params.exampleInternals > 0 then
581          saveInternalsImage(model, validationData, params.exampleInternals)
582          os.exit(0)
583       end
584
585    else
586
587       model = createModel(trainData.width, trainData.height,
588                           params.filterSize, params.nbChannels,
589                           params.nbBlocks)
590
591    end
592
593    trainModel(model, trainData, validationData)
594
595    return model
596
597 end
598
599 ----------------------------------------------------------------------
600 -- main
601
602 for _, c in pairs({
603       'date',
604       'uname -a',
605       'git log -1 --format=%H'
606                  })
607 do
608    logCommand(c)
609 end
610
611 local trainData = loadData(1,
612                            params.nbTrainSamples, 'train')
613
614 local validationData = loadData(params.nbTrainSamples + 1,
615                                 params.nbValidationSamples, 'validation')
616
617 local model = createAndTrainModel(trainData, validationData)
618
619 ----------------------------------------------------------------------
620 -- Test
621
622 local testData = loadData(params.nbTrainSamples + params.nbValidationSamples + 1,
623                           params.nbTestSamples, 'test')
624
625 if params.useGPU then
626    print('Moving the model and criterion to the GPU.')
627    model:cuda()
628 end
629
630 saveResultImage(model, trainData)
631 saveResultImage(model, validationData)
632 saveResultImage(model, testData, 1024)