Making an even deeper model.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27 import distutils.util
28 import re
29
30 from colorama import Fore, Back, Style
31
32 # Pytorch
33
34 import torch
35 import torchvision
36
37 from torch import optim
38 from torch import multiprocessing
39 from torch import FloatTensor as Tensor
40 from torch.autograd import Variable
41 from torch import nn
42 from torch.nn import functional as fn
43
44 from torchvision import datasets, transforms, utils
45
46 # SVRT
47
48 import svrtset
49
50 ######################################################################
51
52 parser = argparse.ArgumentParser(
53     description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
54     formatter_class = argparse.ArgumentDefaultsHelpFormatter
55 )
56
57 parser.add_argument('--nb_train_samples',
58                     type = int, default = 100000)
59
60 parser.add_argument('--nb_test_samples',
61                     type = int, default = 10000)
62
63 parser.add_argument('--nb_validation_samples',
64                     type = int, default = 10000)
65
66 parser.add_argument('--validation_error_threshold',
67                     type = float, default = 0.0,
68                     help = 'Early training termination criterion')
69
70 parser.add_argument('--nb_epochs',
71                     type = int, default = 50)
72
73 parser.add_argument('--batch_size',
74                     type = int, default = 100)
75
76 parser.add_argument('--log_file',
77                     type = str, default = 'default.log')
78
79 parser.add_argument('--nb_exemplar_vignettes',
80                     type = int, default = 32)
81
82 parser.add_argument('--compress_vignettes',
83                     type = distutils.util.strtobool, default = 'True',
84                     help = 'Use lossless compression to reduce the memory footprint')
85
86 parser.add_argument('--model',
87                     type = str, default = 'deepnet',
88                     help = 'What model to use')
89
90 parser.add_argument('--test_loaded_models',
91                     type = distutils.util.strtobool, default = 'False',
92                     help = 'Should we compute the test errors of loaded models')
93
94 parser.add_argument('--problems',
95                     type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
96                     help = 'What problems to process')
97
98 args = parser.parse_args()
99
100 ######################################################################
101
102 log_file = open(args.log_file, 'a')
103 pred_log_t = None
104 last_tag_t = time.time()
105
106 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
107
108 # Log and prints the string, with a time stamp. Does not log the
109 # remark
110
111 def log_string(s, remark = ''):
112     global pred_log_t, last_tag_t
113
114     t = time.time()
115
116     if pred_log_t is None:
117         elapsed = 'start'
118     else:
119         elapsed = '+{:.02f}s'.format(t - pred_log_t)
120
121     pred_log_t = t
122
123     if t > last_tag_t + 3600:
124         last_tag_t = t
125         print(Fore.RED + time.ctime() + Style.RESET_ALL)
126
127     log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n')
128     log_file.flush()
129
130     print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
131
132 ######################################################################
133
134 # Afroze's ShallowNet
135
136 #                       map size   nb. maps
137 #                     ----------------------
138 #    input                128x128    1
139 # -- conv(21x21 x 6)   -> 108x108    6
140 # -- max(2x2)          -> 54x54      6
141 # -- conv(19x19 x 16)  -> 36x36      16
142 # -- max(2x2)          -> 18x18      16
143 # -- conv(18x18 x 120) -> 1x1        120
144 # -- reshape           -> 120        1
145 # -- full(120x84)      -> 84         1
146 # -- full(84x2)        -> 2          1
147
148 class AfrozeShallowNet(nn.Module):
149     name = 'shallownet'
150
151     def __init__(self):
152         super(AfrozeShallowNet, self).__init__()
153         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
154         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
155         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
156         self.fc1 = nn.Linear(120, 84)
157         self.fc2 = nn.Linear(84, 2)
158
159     def forward(self, x):
160         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
161         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
162         x = fn.relu(self.conv3(x))
163         x = x.view(-1, 120)
164         x = fn.relu(self.fc1(x))
165         x = self.fc2(x)
166         return x
167
168 ######################################################################
169
170 # Afroze's DeepNet
171
172 class AfrozeDeepNet(nn.Module):
173
174     name = 'deepnet'
175
176     def __init__(self):
177         super(AfrozeDeepNet, self).__init__()
178         self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
179         self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
180         self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
181         self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
182         self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
183         self.fc1 = nn.Linear(1536, 256)
184         self.fc2 = nn.Linear(256, 256)
185         self.fc3 = nn.Linear(256, 2)
186
187     def forward(self, x):
188         x = self.conv1(x)
189         x = fn.max_pool2d(x, kernel_size=2)
190         x = fn.relu(x)
191
192         x = self.conv2(x)
193         x = fn.max_pool2d(x, kernel_size=2)
194         x = fn.relu(x)
195
196         x = self.conv3(x)
197         x = fn.relu(x)
198
199         x = self.conv4(x)
200         x = fn.relu(x)
201
202         x = self.conv5(x)
203         x = fn.max_pool2d(x, kernel_size=2)
204         x = fn.relu(x)
205
206         x = x.view(-1, 1536)
207
208         x = self.fc1(x)
209         x = fn.relu(x)
210
211         x = self.fc2(x)
212         x = fn.relu(x)
213
214         x = self.fc3(x)
215
216         return x
217
218 ######################################################################
219
220 class DeepNet2(nn.Module):
221     name = 'deepnet2'
222
223     def __init__(self):
224         super(DeepNet2, self).__init__()
225         self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
226         self.conv2 = nn.Conv2d( 32, 256, kernel_size=5, padding=2)
227         self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
228         self.conv4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
229         self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
230         self.fc1 = nn.Linear(4096, 512)
231         self.fc2 = nn.Linear(512, 512)
232         self.fc3 = nn.Linear(512, 2)
233
234     def forward(self, x):
235         x = self.conv1(x)
236         x = fn.max_pool2d(x, kernel_size=2)
237         x = fn.relu(x)
238
239         x = self.conv2(x)
240         x = fn.max_pool2d(x, kernel_size=2)
241         x = fn.relu(x)
242
243         x = self.conv3(x)
244         x = fn.relu(x)
245
246         x = self.conv4(x)
247         x = fn.relu(x)
248
249         x = self.conv5(x)
250         x = fn.max_pool2d(x, kernel_size=2)
251         x = fn.relu(x)
252
253         x = x.view(-1, 4096)
254
255         x = self.fc1(x)
256         x = fn.relu(x)
257
258         x = self.fc2(x)
259         x = fn.relu(x)
260
261         x = self.fc3(x)
262
263         return x
264
265 ######################################################################
266
267 class DeepNet3(nn.Module):
268     name = 'deepnet3'
269
270     def __init__(self):
271         super(DeepNet2, self).__init__()
272         self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
273         self.conv2 = nn.Conv2d( 32, 256, kernel_size=5, padding=2)
274         self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
275         self.conv4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
276         self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
277         self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
278         self.conv7 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
279         self.fc1 = nn.Linear(4096, 512)
280         self.fc2 = nn.Linear(512, 512)
281         self.fc3 = nn.Linear(512, 2)
282
283     def forward(self, x):
284         x = self.conv1(x)
285         x = fn.max_pool2d(x, kernel_size=2)
286         x = fn.relu(x)
287
288         x = self.conv2(x)
289         x = fn.max_pool2d(x, kernel_size=2)
290         x = fn.relu(x)
291
292         x = self.conv3(x)
293         x = fn.relu(x)
294
295         x = self.conv4(x)
296         x = fn.relu(x)
297
298         x = self.conv5(x)
299         x = fn.max_pool2d(x, kernel_size=2)
300         x = fn.relu(x)
301
302         x = self.conv6(x)
303         x = fn.relu(x)
304
305         x = self.conv7(x)
306         x = fn.relu(x)
307
308         x = x.view(-1, 4096)
309
310         x = self.fc1(x)
311         x = fn.relu(x)
312
313         x = self.fc2(x)
314         x = fn.relu(x)
315
316         x = self.fc3(x)
317
318         return x
319
320 ######################################################################
321
322 def nb_errors(model, data_set):
323     ne = 0
324     for b in range(0, data_set.nb_batches):
325         input, target = data_set.get_batch(b)
326         output = model.forward(Variable(input))
327         wta_prediction = output.data.max(1)[1].view(-1)
328
329         for i in range(0, data_set.batch_size):
330             if wta_prediction[i] != target[i]:
331                 ne = ne + 1
332
333     return ne
334
335 ######################################################################
336
337 def train_model(model, model_filename, train_set, validation_set, nb_epochs_done = 0):
338     batch_size = args.batch_size
339     criterion = nn.CrossEntropyLoss()
340
341     if torch.cuda.is_available():
342         criterion.cuda()
343
344     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
345
346     start_t = time.time()
347
348     for e in range(nb_epochs_done, args.nb_epochs):
349         acc_loss = 0.0
350         for b in range(0, train_set.nb_batches):
351             input, target = train_set.get_batch(b)
352             output = model.forward(Variable(input))
353             loss = criterion(output, Variable(target))
354             acc_loss = acc_loss + loss.data[0]
355             model.zero_grad()
356             loss.backward()
357             optimizer.step()
358         dt = (time.time() - start_t) / (e + 1)
359
360         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
361                    ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
362
363         torch.save([ model.state_dict(), e + 1 ], model_filename)
364
365         if validation_set is not None:
366             nb_validation_errors = nb_errors(model, validation_set)
367
368             log_string('validation_error {:.02f}% {:d} {:d}'.format(
369                 100 * nb_validation_errors / validation_set.nb_samples,
370                 nb_validation_errors,
371                 validation_set.nb_samples)
372             )
373
374             if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
375                 log_string('below validation_error_threshold')
376                 break
377
378     return model
379
380 ######################################################################
381
382 for arg in vars(args):
383     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
384
385 ######################################################################
386
387 def int_to_suffix(n):
388     if n >= 1000000 and n%1000000 == 0:
389         return str(n//1000000) + 'M'
390     elif n >= 1000 and n%1000 == 0:
391         return str(n//1000) + 'K'
392     else:
393         return str(n)
394
395 class vignette_logger():
396     def __init__(self, delay_min = 60):
397         self.start_t = time.time()
398         self.last_t = self.start_t
399         self.delay_min = delay_min
400
401     def __call__(self, n, m):
402         t = time.time()
403         if t > self.last_t + self.delay_min:
404             dt = (t - self.start_t) / m
405             log_string('sample_generation {:d} / {:d}'.format(
406                 m,
407                 n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
408             )
409             self.last_t = t
410
411 def save_examplar_vignettes(data_set, nb, name):
412     n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)
413
414     for k in range(0, nb):
415         b = n[k] // data_set.batch_size
416         m = n[k] % data_set.batch_size
417         i, t = data_set.get_batch(b)
418         i = i[m].float()
419         i.sub_(i.min())
420         i.div_(i.max())
421         if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2))
422         patchwork[k].copy_(i)
423
424     torchvision.utils.save_image(patchwork, name)
425
426 ######################################################################
427
428 if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
429     print('The number of samples must be a multiple of the batch size.')
430     raise
431
432 log_string('############### start ###############')
433
434 if args.compress_vignettes:
435     log_string('using_compressed_vignettes')
436     VignetteSet = svrtset.CompressedVignetteSet
437 else:
438     log_string('using_uncompressed_vignettes')
439     VignetteSet = svrtset.VignetteSet
440
441 ########################################
442 model_class = None
443 for m in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2, DeepNet3 ]:
444     if args.model == m.name:
445         model_class = m
446         break
447 if model_class is None:
448     print('Unknown model ' + args.model)
449     raise
450
451 log_string('using model class ' + m.name)
452 ########################################
453
454 for problem_number in map(int, args.problems.split(',')):
455
456     log_string('############### problem ' + str(problem_number) + ' ###############')
457
458     model = model_class()
459
460     if torch.cuda.is_available(): model.cuda()
461
462     model_filename = model.name + '_pb:' + \
463                      str(problem_number) + '_ns:' + \
464                      int_to_suffix(args.nb_train_samples) + '.state'
465
466     nb_parameters = 0
467     for p in model.parameters(): nb_parameters += p.numel()
468     log_string('nb_parameters {:d}'.format(nb_parameters))
469
470     ##################################################
471     # Tries to load the model
472
473     try:
474         model_state_dict, nb_epochs_done = torch.load(model_filename)
475         model.load_state_dict(model_state_dict)
476         log_string('loaded_model ' + model_filename)
477     except:
478         nb_epochs_done = 0
479
480
481     ##################################################
482     # Train if necessary
483
484     if nb_epochs_done < args.nb_epochs:
485
486         log_string('training_model ' + model_filename)
487
488         t = time.time()
489
490         train_set = VignetteSet(problem_number,
491                                 args.nb_train_samples, args.batch_size,
492                                 cuda = torch.cuda.is_available(),
493                                 logger = vignette_logger())
494
495         log_string('data_generation {:0.2f} samples / s'.format(
496             train_set.nb_samples / (time.time() - t))
497         )
498
499         if args.nb_exemplar_vignettes > 0:
500             save_examplar_vignettes(train_set, args.nb_exemplar_vignettes,
501                                     'examplar_{:d}.png'.format(problem_number))
502
503         if args.validation_error_threshold > 0.0:
504             validation_set = VignetteSet(problem_number,
505                                          args.nb_validation_samples, args.batch_size,
506                                          cuda = torch.cuda.is_available(),
507                                          logger = vignette_logger())
508         else:
509             validation_set = None
510
511         train_model(model, model_filename, train_set, validation_set, nb_epochs_done = nb_epochs_done)
512         log_string('saved_model ' + model_filename)
513
514         nb_train_errors = nb_errors(model, train_set)
515
516         log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
517             problem_number,
518             100 * nb_train_errors / train_set.nb_samples,
519             nb_train_errors,
520             train_set.nb_samples)
521         )
522
523     ##################################################
524     # Test if necessary
525
526     if nb_epochs_done < args.nb_epochs or args.test_loaded_models:
527
528         t = time.time()
529
530         test_set = VignetteSet(problem_number,
531                                args.nb_test_samples, args.batch_size,
532                                cuda = torch.cuda.is_available())
533
534         nb_test_errors = nb_errors(model, test_set)
535
536         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
537             problem_number,
538             100 * nb_test_errors / test_set.nb_samples,
539             nb_test_errors,
540             test_set.nb_samples)
541         )
542
543 ######################################################################