The validation error was computed on the training data. That's embarassing.
[dyncnn.git] / dyncnn.lua
1 #!/usr/bin/env luajit
2
3 --[[
4
5    dyncnn is a deep-learning algorithm for the prediction of
6    interacting object dynamics
7
8    Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/
9    Written by Francois Fleuret <francois.fleuret@idiap.ch>
10
11    This file is part of dyncnn.
12
13    dyncnn is free software: you can redistribute it and/or modify it
14    under the terms of the GNU General Public License version 3 as
15    published by the Free Software Foundation.
16
17    dyncnn is distributed in the hope that it will be useful, but
18    WITHOUT ANY WARRANTY; without even the implied warranty of
19    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20    General Public License for more details.
21
22    You should have received a copy of the GNU General Public License
23    along with dyncnn.  If not, see <http://www.gnu.org/licenses/>.
24
25 ]]--
26
27 require 'torch'
28 require 'nn'
29 require 'optim'
30 require 'image'
31 require 'pl'
32
33 ----------------------------------------------------------------------
34
35 local opt = lapp[[
36    --seed                (default 1)               random seed
37
38    --learningStateFile   (default '')
39    --dataDir             (default './data/10p-mg/')
40    --resultDir           (default '/tmp/dyncnn')
41
42    --learningRate        (default -1)
43    --momentum            (default -1)
44    --nbEpochs            (default -1)              nb of epochs for the heavy setting
45
46    --heavy                                         use the heavy configuration
47    --nbChannels          (default -1)              nb of channels in the internal layers
48    --resultFreq          (default 100)
49
50    --noLog                                         supress logging
51
52    --exampleInternals    (default -1)
53 ]]
54
55 ----------------------------------------------------------------------
56
57 commandLine=''
58 for i = 0, #arg do
59    commandLine = commandLine ..  ' \'' .. arg[i] .. '\''
60 end
61
62 ----------------------------------------------------------------------
63
64 colors = sys.COLORS
65
66 global = {}
67
68 function logString(s, c)
69    if global.logFile then
70       global.logFile:write(s)
71       global.logFile:flush()
72    end
73    local c = c or colors.black
74    io.write(c .. s)
75    io.flush()
76 end
77
78 function logCommand(c)
79    logString('[' .. c .. '] -> [' .. sys.execute(c) .. ']\n', colors.blue)
80 end
81
82 logString('commandline: ' .. commandLine .. '\n', colors.blue)
83
84 logCommand('mkdir -v -p ' .. opt.resultDir)
85
86 if not opt.noLog then
87    global.logName = opt.resultDir .. '/log'
88    global.logFile = io.open(global.logName, 'a')
89 end
90
91 ----------------------------------------------------------------------
92
93 alreadyLoggedString = {}
94
95 function logOnce(s)
96    local l = debug.getinfo(1).currentline
97    if not alreadyLoggedString[l] then
98       logString('@line ' .. l .. ' ' .. s, colors.red)
99       alreadyLoggedString[l] = s
100    end
101 end
102
103 ----------------------------------------------------------------------
104
105 nbThreads = os.getenv('TORCH_NB_THREADS') or 1
106
107 useGPU = os.getenv('TORCH_USE_GPU') == 'yes'
108
109 for _, c in pairs({ 'date',
110                     'uname -a',
111                     'git log -1 --format=%H'
112                  })
113 do
114    logCommand(c)
115 end
116
117 logString('useGPU is \'' .. tostring(useGPU) .. '\'.\n')
118
119 logString('nbThreads is \'' .. nbThreads .. '\'.\n')
120
121 ----------------------------------------------------------------------
122
123 torch.setnumthreads(nbThreads)
124 torch.setdefaulttensortype('torch.FloatTensor')
125 torch.manualSeed(opt.seed)
126
127 mynn = {}
128
129 -- By default, mynn returns the entries from nn
130 local mt = {}
131 function mt.__index(table, key)
132    return nn[key]
133 end
134 setmetatable(mynn, mt)
135
136 -- These are the tensors that can be kept on the CPU
137 mynn.SlowTensor = torch.Tensor
138 -- These are the tensors that should be moved to the GPU
139 mynn.FastTensor = torch.Tensor
140
141 ----------------------------------------------------------------------
142
143 if useGPU then
144    require 'cutorch'
145    require 'cunn'
146    require 'cudnn'
147    mynn.FastTensor = torch.CudaTensor
148    mynn.SpatialConvolution = cudnn.SpatialConvolution
149 end
150
151 ----------------------------------------------------------------------
152
153 config = {}
154 config.learningRate = 0.1
155 config.momentum = 0
156 config.batchSize = 128
157 config.filterSize = 5
158
159 if opt.heavy then
160
161    logString('Using the heavy configuration.\n')
162    config.nbChannels = 16
163    config.nbBlocks = 4
164    config.nbEpochs = 250
165    config.nbEpochsInit = 100
166    config.nbTrainSamples = 32768
167    config.nbValidationSamples = 1024
168    config.nbTestSamples = 1024
169
170 else
171
172    logString('Using the light configuration.\n')
173    config.nbChannels = 2
174    config.nbBlocks = 2
175    config.nbEpochs = 6
176    config.nbEpochsInit = 3
177    config.nbTrainSamples = 1024
178    config.nbValidationSamples = 1024
179    config.nbTestSamples = 1024
180
181 end
182
183 if opt.nbEpochs > 0 then
184    config.nbEpochs = opt.nbEpochs
185 end
186
187 if opt.nbChannels > 0 then
188    config.nbChannels = opt.nbChannels
189 end
190
191 if opt.learningRate > 0 then
192    config.learningRate = opt.learningRate
193 end
194
195 if opt.momentum >= 0 then
196    config.momentum = opt.momentum
197 end
198
199 ----------------------------------------------------------------------
200
201 function tensorCensus(tensorType, model)
202
203    local nb = {}
204
205    local function countThings(m)
206       for k, i in pairs(m) do
207          if torch.type(i) == tensorType then
208             nb[k] = (nb[k] or 0) + i:nElement()
209          end
210       end
211    end
212
213    model:apply(countThings)
214
215    return nb
216
217 end
218
219 ----------------------------------------------------------------------
220
221 function loadData(first, nb, name)
222    logString('Loading data `' .. name .. '\'.\n')
223
224    local persistentFileName = string.format('%s/persistent_%d_%d.dat',
225                                             opt.dataDir,
226                                             first,
227                                             nb)
228
229    -- This is at what framerate we work. It is greater than 1 so that
230    -- we can keep on disk sequences at a higher frame rate for videos
231    -- and explaining materials
232
233    local frameRate = 4
234
235    local data
236
237    if not path.exists(persistentFileName) then
238       logString(string.format('No persistent data structure, creating it (%d samples).\n', nb))
239       local data = {}
240       data.name = name
241       data.nbSamples = nb
242       data.width = 64
243       data.height = 64
244       data.input = mynn.SlowTensor(data.nbSamples, 2, data.height, data.width)
245       data.target = mynn.SlowTensor(data.nbSamples, 1, data.height, data.width)
246
247       for i = 1, data.nbSamples do
248          local n = i-1 + first-1
249          local prefix = string.format('%s/%03d/dyn_%06d',
250                                       opt.dataDir,
251                                       math.floor(n/1000), n)
252
253          function localLoad(filename, tensor)
254             local tmp
255             tmp = image.load(filename)
256             tmp:mul(-1.0):add(1.0)
257             tensor:copy(torch.max(tmp, 1))
258          end
259
260          localLoad(prefix .. '_world_000.png', data.input[i][1])
261          localLoad(prefix .. '_grab.png',    data.input[i][2])
262          localLoad(string.format('%s_world_%03d.png', prefix, frameRate),
263                    data.target[i][1])
264       end
265
266       data.persistentFileName = persistentFileName
267
268       torch.save(persistentFileName, data)
269    end
270
271    logCommand('sha256sum -b ' .. persistentFileName)
272
273    data = torch.load(persistentFileName)
274
275    return data
276 end
277
278 ----------------------------------------------------------------------
279
280 -- This function gets as input a list of tensors of arbitrary
281 -- dimensions each, but whose two last dimension stands for height x
282 -- width. It creates an image tensor (2d, one channel) with each
283 -- argument tensor unfolded per row.
284
285 function imageFromTensors(bt, signed)
286    local gap = 1
287    local tgap = -1
288    local width = 0
289    local height = gap
290
291    for _, t in pairs(bt) do
292       -- print(t:size())
293       local d = t:dim()
294       local h, w = t:size(d - 1), t:size(d)
295       local n = t:nElement() / (w * h)
296       width = math.max(width, gap + n * (gap + w))
297       height = height + gap + tgap + gap + h
298    end
299
300    local e = torch.Tensor(3, height, width):fill(1.0)
301    local y0 = 1 + gap
302
303    for _, t in pairs(bt) do
304       local d = t:dim()
305       local h, w = t:size(d - 1), t:size(d)
306       local n = t:nElement() / (w * h)
307       local z = t:norm() / math.sqrt(t:nElement())
308
309       local x0 = 1 + gap + math.floor( (width - n * (w + gap)) /2 )
310       local u = torch.Tensor(t:size()):copy(t):resize(n, h, w)
311       for m = 1, n do
312
313          for c = 1, 3 do
314             for y = 0, h+1 do
315                e[c][y0 + y - 1][x0     - 1] = 0.0
316                e[c][y0 + y - 1][x0 + w    ] = 0.0
317             end
318             for x = 0, w+1 do
319                e[c][y0     - 1][x0 + x - 1] = 0.0
320                e[c][y0 + h    ][x0 + x - 1] = 0.0
321             end
322          end
323
324          for y = 1, h do
325             for x = 1, w do
326                local v = u[m][y][x] / z
327                local r, g, b
328                if signed then
329                   if v < -1 then
330                      r, g, b = 0.0, 0.0, 1.0
331                   elseif v > 1 then
332                      r, g, b = 1.0, 0.0, 0.0
333                   elseif v >= 0 then
334                      r, g, b = 1.0, 1.0 - v, 1.0 - v
335                   else
336                      r, g, b = 1.0 + v, 1.0 + v, 1.0
337                   end
338                else
339                   if v <= 0 then
340                      r, g, b = 1.0, 1.0, 1.0
341                   elseif v > 1 then
342                      r, g, b = 0.0, 0.0, 0.0
343                   else
344                      r, g, b = 1.0 - v, 1.0 - v, 1.0 - v
345                   end
346                end
347                e[1][y0 + y - 1][x0 + x - 1] = r
348                e[2][y0 + y - 1][x0 + x - 1] = g
349                e[3][y0 + y - 1][x0 + x - 1] = b
350             end
351          end
352          x0 = x0 + w + gap
353       end
354       y0 = y0 + h + gap + tgap + gap
355    end
356
357    return e
358 end
359
360 function collectAllOutputs(model, collection, which)
361    if torch.type(model) == 'nn.Sequential' then
362       for i = 1, #model.modules do
363          collectAllOutputs(model.modules[i], collection, which)
364       end
365    elseif not which or which[torch.type(model)] then
366       local t = torch.type(model.output)
367       if t == 'torch.FloatTensor' or t == 'torch.CudaTensor' then
368          collection.nb = collection.nb + 1
369          collection.outputs[collection.nb] = model.output
370       end
371    end
372 end
373
374 function saveInternalsImage(model, data, n)
375    -- Explicitely copy to keep input as a mynn.FastTensor
376    local input = mynn.FastTensor(1, 2, data.height, data.width)
377    input:copy(data.input:narrow(1, n, 1))
378
379    local output = model:forward(input)
380
381    local collection = {}
382    collection.outputs = {}
383    collection.nb = 1
384    collection.outputs[collection.nb] = input
385
386    local which = {}
387    which['nn.ReLU'] = true
388    collectAllOutputs(model, collection, which)
389
390    if collection.outputs[collection.nb] ~= model.output then
391       collection.nb = collection.nb + 1
392       collection.outputs[collection.nb] = model.output
393    end
394
395    local fileName = string.format('%s/internals_%s_%06d.png',
396                                   opt.resultDir,
397                                   data.name, n)
398
399    logString('Saving ' .. fileName .. '\n')
400    image.save(fileName, imageFromTensors(collection.outputs))
401 end
402
403 ----------------------------------------------------------------------
404
405 function saveResultImage(model, data, prefix, nbMax, highlight)
406    local l2criterion = nn.MSECriterion()
407
408    if useGPU then
409       logString('Moving the criterion to the GPU.\n')
410       l2criterion:cuda()
411    end
412
413    local prefix = prefix or 'result'
414    local result = torch.Tensor(data.height * 4 + 5, data.width + 2)
415    local input = mynn.FastTensor(1, 2, data.height, data.width)
416    local target = mynn.FastTensor(1, 1, data.height, data.width)
417
418    local nbMax = nbMax or 50
419
420    local nb = math.min(nbMax, data.nbSamples)
421
422    model:evaluate()
423
424    logString(string.format('Write %d result images `%s\' for set `%s\' in %s.\n',
425                            nb, prefix, data.name,
426                            opt.resultDir))
427
428    for n = 1, nb do
429
430       -- Explicitely copy to keep input as a mynn.FastTensor
431       input:copy(data.input:narrow(1, n, 1))
432       target:copy(data.target:narrow(1, n, 1))
433
434       local output = model:forward(input)
435
436       local loss = l2criterion:forward(output, target)
437
438       result:fill(1.0)
439
440       if highlight then
441          for i = 1, data.height do
442             for j = 1, data.width do
443                local v = data.input[n][1][i][j]
444                result[1 + i + 0 * (data.height + 1)][1 + j] = data.input[n][2][i][j]
445                result[1 + i + 1 * (data.height + 1)][1 + j] = v
446                local a = data.target[n][1][i][j]
447                local b = output[1][1][i][j]
448                result[1 + i + 2 * (data.height + 1)][1 + j] =
449                   a * math.min(1, 0.1 + 2.0 * math.abs(a - v))
450                result[1 + i + 3 * (data.height + 1)][1 + j] =
451                   b * math.min(1, 0.1 + 2.0 * math.abs(b - v))
452             end
453          end
454       else
455          for i = 1, data.height do
456             for j = 1, data.width do
457                result[1 + i + 0 * (data.height + 1)][1 + j] = data.input[n][2][i][j]
458                result[1 + i + 1 * (data.height + 1)][1 + j] = data.input[n][1][i][j]
459                result[1 + i + 2 * (data.height + 1)][1 + j] = data.target[n][1][i][j]
460                result[1 + i + 3 * (data.height + 1)][1 + j] = output[1][1][i][j]
461             end
462          end
463       end
464
465       result:mul(-1.0):add(1.0)
466
467       local fileName = string.format('%s/%s_%s_%06d.png',
468                                      opt.resultDir,
469                                      prefix,
470                                      data.name, n)
471
472       logString(string.format('LOSS_ON_SAMPLE %f %s\n', loss, fileName))
473
474       image.save(fileName, result)
475    end
476 end
477
478 ----------------------------------------------------------------------
479
480 function createTower(filterSize, nbChannels, nbBlocks)
481    local tower = mynn.Sequential()
482
483    for b = 1, nbBlocks do
484       local block = mynn.Sequential()
485
486       block:add(mynn.SpatialConvolution(nbChannels,
487                                         nbChannels,
488                                         filterSize, filterSize,
489                                         1, 1,
490                                         (filterSize - 1) / 2, (filterSize - 1) / 2))
491       block:add(mynn.SpatialBatchNormalization(nbChannels))
492       block:add(mynn.ReLU(true))
493
494       block:add(mynn.SpatialConvolution(nbChannels,
495                                         nbChannels,
496                                         filterSize, filterSize,
497                                         1, 1,
498                                         (filterSize - 1) / 2, (filterSize - 1) / 2))
499
500       local parallel = mynn.ConcatTable()
501       parallel:add(block):add(mynn.Identity())
502
503       tower:add(parallel):add(mynn.CAddTable(true))
504
505       tower:add(mynn.SpatialBatchNormalization(nbChannels))
506       tower:add(mynn.ReLU(true))
507    end
508
509    return tower
510 end
511
512 function createModel(filterSize, nbChannels, nbBlocks)
513    local model = mynn.Sequential()
514
515    model:add(mynn.SpatialConvolution(2,
516                                      nbChannels,
517                                      filterSize, filterSize,
518                                      1, 1,
519                                      (filterSize - 1) / 2, (filterSize - 1) / 2))
520
521    model:add(mynn.SpatialBatchNormalization(nbChannels))
522    model:add(mynn.ReLU(true))
523
524    local towerCode   = createTower(filterSize, nbChannels, nbBlocks)
525    local towerDecode = createTower(filterSize, nbChannels, nbBlocks)
526
527    model:add(towerCode)
528    model:add(towerDecode)
529
530    -- Decode to a single channel, which is the final image
531    model:add(mynn.SpatialConvolution(nbChannels,
532                                      1,
533                                      filterSize, filterSize,
534                                      1, 1,
535                                      (filterSize - 1) / 2, (filterSize - 1) / 2))
536
537    return model
538 end
539
540 ----------------------------------------------------------------------
541
542 function fillBatch(data, first, nb, batch, permutation)
543    for k = 1, nb do
544       local i
545       if permutation then
546          i = permutation[first + k - 1]
547       else
548          i = first + k - 1
549       end
550       batch.input[k] = data.input[i]
551       batch.target[k] = data.target[i]
552    end
553 end
554
555 function trainModel(model,
556                     trainData, validationData, nbEpochs, learningRate,
557                     learningStateFile)
558
559    local l2criterion = nn.MSECriterion()
560    local batchSize = config.batchSize
561
562    if useGPU then
563       logString('Moving the criterion to the GPU.\n')
564       l2criterion:cuda()
565    end
566
567    local batch = {}
568    batch.input = mynn.FastTensor(batchSize, 2, trainData.height, trainData.width)
569    batch.target = mynn.FastTensor(batchSize, 1, trainData.height, trainData.width)
570
571    local startingEpoch = 1
572
573    if model.epoch then
574       startingEpoch = model.epoch + 1
575    end
576
577    if model.RNGState then
578       torch.setRNGState(model.RNGState)
579    end
580
581    logString('Starting training.\n')
582
583    local parameters, gradParameters = model:getParameters()
584    logString(string.format('model has %d parameters.\n', parameters:storage():size(1)))
585
586    local averageTrainLoss, averageValidationLoss
587    local trainTime, validationTime
588
589    local sgdState = {
590       learningRate = config.learningRate,
591       momentum = config.momentum,
592       learningRateDecay = 0
593    }
594
595    for e = startingEpoch, nbEpochs do
596
597       model:training()
598
599       local permutation = torch.randperm(trainData.nbSamples)
600
601       local accLoss = 0.0
602       local nbBatches = 0
603       local startTime = sys.clock()
604
605       for b = 1, trainData.nbSamples, batchSize do
606
607          fillBatch(trainData, b, batchSize, batch, permutation)
608
609          local opfunc = function(x)
610             -- Surprisingly copy() needs this check
611             if x ~= parameters then
612                parameters:copy(x)
613             end
614
615             local output = model:forward(batch.input)
616             local loss = l2criterion:forward(output, batch.target)
617
618             local dLossdOutput = l2criterion:backward(output, batch.target)
619             gradParameters:zero()
620             model:backward(batch.input, dLossdOutput)
621
622             accLoss = accLoss + loss
623             nbBatches = nbBatches + 1
624
625             return loss, gradParameters
626          end
627
628          optim.sgd(opfunc, parameters, sgdState)
629
630       end
631
632       trainTime = sys.clock() - startTime
633       averageTrainLoss = accLoss / nbBatches
634
635       ----------------------------------------------------------------------
636       -- Validation losses
637       do
638          model:evaluate()
639
640          local accLoss = 0.0
641          local nbBatches = 0
642          local startTime = sys.clock()
643
644          for b = 1, validationData.nbSamples, batchSize do
645             fillBatch(validationData, b, batchSize, batch)
646             local output = model:forward(batch.input)
647             accLoss = accLoss + l2criterion:forward(output, batch.target)
648             nbBatches = nbBatches + 1
649          end
650
651          validationTime = sys.clock() - startTime
652          averageValidationLoss = accLoss / nbBatches;
653       end
654
655       logString(string.format('Epoch train %0.2fs (%0.2fms / sample), validation %0.2fs (%0.2fms / sample).\n',
656                               trainTime,
657                               1000 * trainTime / trainData.nbSamples,
658                               validationTime,
659                               1000 * validationTime / validationData.nbSamples))
660
661       logString(string.format('LOSS %d %f %f\n', e, averageTrainLoss, averageValidationLoss),
662                 colors.green)
663
664       ----------------------------------------------------------------------
665       -- Save a persistent state so that we can restart from there
666
667       if learningStateFile then
668          model.RNGState = torch.getRNGState()
669          model.epoch = e
670          model:clearState()
671          logString('Writing ' .. learningStateFile .. '.\n')
672          torch.save(learningStateFile, model)
673       end
674
675       ----------------------------------------------------------------------
676       -- Save a duplicate of the persistent state from time to time
677
678       if opt.resultFreq > 0 and e%opt.resultFreq == 0 then
679          torch.save(string.format('%s/epoch_%05d_model', opt.resultDir, e), model)
680          saveResultImage(model, trainData)
681          saveResultImage(model, validationData)
682       end
683
684    end
685
686 end
687
688 function createAndTrainModel(trainData, validationData)
689
690    local model
691
692    local learningStateFile = opt.learningStateFile
693
694    if learningStateFile == '' then
695       learningStateFile = opt.resultDir .. '/learning.state'
696    end
697
698    local gotlearningStateFile
699
700    logString('Using the learning state file ' .. learningStateFile .. '\n')
701
702    if pcall(function () model = torch.load(learningStateFile) end) then
703
704       gotlearningStateFile = true
705
706    else
707
708       model = createModel(config.filterSize, config.nbChannels, config.nbBlocks)
709
710       if useGPU then
711          logString('Moving the model to the GPU.\n')
712          model:cuda()
713       end
714
715    end
716
717    logString(tostring(model) .. '\n')
718
719    if gotlearningStateFile then
720       logString(string.format('Found a learning state with %d epochs finished.\n', model.epoch),
721                 colors.red)
722    end
723
724    if opt.exampleInternals > 0 then
725       saveInternalsImage(model, validationData, opt.exampleInternals)
726       os.exit(0)
727    end
728
729    trainModel(model,
730               trainData, validationData,
731               config.nbEpochs, config.learningRate,
732               learningStateFile)
733
734    return model
735
736 end
737
738 for i, j in pairs(config) do
739    logString('config ' .. i .. ' = \'' .. j ..'\'\n')
740 end
741
742 local trainData = loadData(1, config.nbTrainSamples, 'train')
743 local validationData = loadData(config.nbTrainSamples + 1, config.nbValidationSamples, 'validation')
744 local testData = loadData(config.nbTrainSamples + config.nbValidationSamples + 1, config.nbTestSamples, 'test')
745
746 local model = createAndTrainModel(trainData, validationData)
747
748 saveResultImage(model, trainData)
749 saveResultImage(model, validationData)
750 saveResultImage(model, testData, nil, testData.nbSamples)