Minor bug fix.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27
28 import distutils.util
29 import re
30 import signal
31
32 from colorama import Fore, Back, Style
33
34 # Pytorch
35
36 import torch
37 import torchvision
38
39 from torch import optim
40 from torch import multiprocessing
41 from torch import FloatTensor as Tensor
42 from torch.autograd import Variable
43 from torch import nn
44 from torch.nn import functional as fn
45
46 from torchvision import datasets, transforms, utils
47
48 # SVRT
49
50 import svrtset
51
52 ######################################################################
53
# Command-line interface of the script. Boolean switches are converted with
# distutils.util.strtobool.

parser = argparse.ArgumentParser(
    description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

# One row per option, in the order they are registered (and listed by
# --help): (flag, type converter, default, help text). A help text of None
# is the same as not passing help at all.
_option_table = [
    ('--nb_train_samples', int, 100000, None),
    ('--nb_test_samples', int, 10000, None),
    ('--nb_validation_samples', int, 10000, None),
    ('--validation_error_threshold', float, 0.0,
     'Early training termination criterion'),
    ('--nb_epochs', int, 50, None),
    ('--batch_size', int, 100, None),
    ('--log_file', str, 'default.log', None),
    ('--nb_exemplar_vignettes', int, 32, None),
    ('--compress_vignettes', distutils.util.strtobool, 'True',
     'Use lossless compression to reduce the memory footprint'),
    ('--save_test_mistakes', distutils.util.strtobool, 'False', None),
    ('--model', str, 'deepnet',
     'What model to use'),
    ('--test_loaded_models', distutils.util.strtobool, 'False',
     'Should we compute the test errors of loaded models'),
    ('--problems', str,
     '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
     'What problems to process'),
]

for _flag, _converter, _default, _help in _option_table:
    parser.add_argument(_flag, type = _converter, default = _default, help = _help)

args = parser.parse_args()
104
105 ######################################################################
106
# Log file shared by log_string(); opened in append mode so successive runs
# accumulate in the same file. It is never explicitly closed in this script.
log_file = open(args.log_file, 'a')

# pred_log_t: time of the previous log_string() call (None before the first).
# last_tag_t: time at which the last hourly time banner was printed.
pred_log_t, last_tag_t = None, time.time()

print(''.join([Fore.RED, 'Logging into ', args.log_file, Style.RESET_ALL]))
112
def log_string(s, remark = ''):
    """Print *s* with a time stamp and append it to the log file.

    Every line is prefixed with the wall-clock time and the time elapsed
    since the previous call ('start' for the very first call). The optional
    *remark* is only shown on the console (e.g. an ETA); it is never written
    to log_file. When more than an hour has passed since the previous banner,
    a bare red time stamp is printed first.

    Args:
        s: message written to the log file and printed.
        remark: console-only suffix, printed after *s* in cyan.
    """
    global pred_log_t, last_tag_t

    t = time.time()
    # Derive every time stamp from the same clock reading so the log file,
    # the console and the elapsed time always agree with each other.
    stamp = time.ctime(t)

    if pred_log_t is None:
        elapsed = 'start'
    else:
        elapsed = '+{:.02f}s'.format(t - pred_log_t)

    pred_log_t = t

    if t > last_tag_t + 3600:
        last_tag_t = t
        print(Fore.RED + stamp + Style.RESET_ALL)

    # Underscores make the time stamp a single whitespace-free field in the file.
    log_file.write(stamp.replace(' ', '_') + ' ' + elapsed + ' ' + s + '\n')
    log_file.flush()

    print(Fore.BLUE + stamp + ' ' + Fore.GREEN + elapsed
          + Style.RESET_ALL
          + ' '
          + s + Fore.CYAN + remark
          + Style.RESET_ALL)
140
141 ######################################################################
142
# On SIGINT / SIGTERM, record the interruption in the log before leaving.
# SystemExit is raised directly (exactly what sys.exit() does) rather than
# calling the site-module exit() helper, which Python's documentation
# reserves for the interactive interpreter.

def handler_sigint(signum, frame):
    """SIGINT handler: log the signal, then exit with status 0."""
    log_string('got sigint')
    raise SystemExit(0)

def handler_sigterm(signum, frame):
    """SIGTERM handler: log the signal, then exit with status 0."""
    log_string('got sigterm')
    raise SystemExit(0)

signal.signal(signal.SIGINT, handler_sigint)
signal.signal(signal.SIGTERM, handler_sigterm)
153
154 ######################################################################
155
# Afroze's ShallowNet
#
# Feature-map sizes, layer by layer, for a 128x128 single-channel input:
#
#                       map size   nb. maps
#                     ----------------------
#    input                128x128    1
# -- conv(21x21 x 6)   -> 108x108    6
# -- max(2x2)          -> 54x54      6
# -- conv(19x19 x 16)  -> 36x36      16
# -- max(2x2)          -> 18x18      16
# -- conv(18x18 x 120) -> 1x1        120
# -- reshape           -> 120        1
# -- full(120x84)      -> 84         1
# -- full(84x2)        -> 2          1

class AfrozeShallowNet(nn.Module):
    """Three convolutions followed by two fully connected layers (2 outputs)."""

    # Value of --model that selects this architecture.
    name = 'shallownet'

    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        # Created in this exact order: it fixes both the random initialization
        # sequence and the state_dict keys stored in checkpoints.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)

    def forward(self, x):
        # Each convolution is optionally followed by 2x2 max-pooling, then ReLU.
        for conv, pooled in ((self.conv1, True), (self.conv2, True), (self.conv3, False)):
            x = conv(x)
            if pooled:
                x = fn.max_pool2d(x, kernel_size=2)
            x = fn.relu(x)

        # conv3 leaves a 1x1 map with 120 channels: flatten for the classifier.
        features = x.view(-1, 120)
        hidden = fn.relu(self.fc1(features))
        return self.fc2(hidden)
189
190 ######################################################################
191
# Afroze's DeepNet

class AfrozeDeepNet(nn.Module):
    """Five convolutions followed by three fully connected layers (2 outputs).

    For a 128x128 single-channel input the spatial size goes
    128 -> 32 (conv1, stride 4) -> 16 (pool) -> 8 (pool) -> 4 (pool), so the
    last feature map holds 96 x 4 x 4 = 1536 values.
    """

    # Value of --model that selects this architecture.
    name = 'deepnet'

    def __init__(self):
        super(AfrozeDeepNet, self).__init__()
        # Creation order fixes the initialization sequence and checkpoint keys.
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(1536, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    def forward(self, x):
        # Convolutional trunk: conv -> [2x2 max-pool] -> ReLU, with pooling
        # after conv1, conv2 and conv5 only.
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        x = fn.relu(self.conv4(x))
        x = fn.relu(fn.max_pool2d(self.conv5(x), kernel_size=2))

        # Classifier head on the flattened feature map.
        x = fn.relu(self.fc1(x.view(-1, 1536)))
        x = fn.relu(self.fc2(x))
        return self.fc3(x)
239
240 ######################################################################
241
class DeepNet2(nn.Module):
    """Wider variant of AfrozeDeepNet: 256-channel convolutions, 512-unit FC layers.

    The last convolutional map is flattened to 4096 = 256 x 4 x 4 values.
    """

    # Value of --model that selects this architecture.
    name = 'deepnet2'

    def __init__(self):
        super(DeepNet2, self).__init__()
        # Creation order fixes the initialization sequence and checkpoint keys.
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 256, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(4096, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 2)

    def forward(self, x):
        # (convolution, followed by 2x2 max-pooling?) in application order;
        # every stage ends with a ReLU.
        stages = ((self.conv1, True), (self.conv2, True), (self.conv3, False),
                  (self.conv4, False), (self.conv5, True))
        for conv, pool in stages:
            x = conv(x)
            if pool:
                x = fn.max_pool2d(x, kernel_size=2)
            x = fn.relu(x)

        x = x.view(-1, 4096)
        for hidden_layer in (self.fc1, self.fc2):
            x = fn.relu(hidden_layer(x))
        return self.fc3(x)
286
287 ######################################################################
288
class DeepNet3(nn.Module):
    """Deeper variant: seven 128-channel convolutions, then three FC layers.

    conv6 and conv7 run on the pooled map after conv5; their 128-channel
    output is flattened to 2048 = 128 x 4 x 4 values.
    """

    # Value of --model that selects this architecture.
    name = 'deepnet3'

    def __init__(self):
        super(DeepNet3, self).__init__()
        # Creation order fixes the initialization sequence and checkpoint keys.
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 128, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv7 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(2048, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    def forward(self, x):
        # Layers followed by 2x2 max-pooling; all others go straight to ReLU.
        pooled_after = (self.conv1, self.conv2, self.conv5)
        trunk = (self.conv1, self.conv2, self.conv3, self.conv4,
                 self.conv5, self.conv6, self.conv7)
        for conv in trunk:
            x = conv(x)
            if conv in pooled_after:
                x = fn.max_pool2d(x, kernel_size=2)
            x = fn.relu(x)

        x = x.view(-1, 2048)
        x = fn.relu(self.fc1(x))
        x = fn.relu(self.fc2(x))
        return self.fc3(x)
341
342 ######################################################################
343
def nb_errors(model, data_set, mistake_filename_pattern = None):
    """Count the samples of *data_set* that *model* misclassifies.

    The predicted class of a sample is the index of its largest output
    (winner takes all). When *mistake_filename_pattern* is given, every
    misclassified input is also saved as an image named
    ``mistake_filename_pattern.format(sample_index, target)``.

    Args:
        model: object whose ``forward`` maps a batch of inputs to per-class
            scores.
        data_set: object exposing ``nb_batches``, ``batch_size`` and
            ``get_batch(b)`` returning an ``(input, target)`` pair.
        mistake_filename_pattern: optional ``str.format`` pattern with two
            fields (global sample index, target class).

    Returns:
        The number of misclassified samples, as an int.
    """
    ne = 0
    for b in range(0, data_set.nb_batches):
        batch_input, target = data_set.get_batch(b)
        output = model.forward(Variable(batch_input))
        # Winner-takes-all: the prediction is the arg-max over the outputs.
        wta_prediction = output.data.max(1)[1].view(-1)

        for i in range(0, data_set.batch_size):
            if wta_prediction[i] != target[i]:
                ne = ne + 1
                if mistake_filename_pattern is not None:
                    # Rescale the vignette to [0, 1] before saving it.
                    img = batch_input[i].clone()
                    img.sub_(img.min())
                    peak = img.max()
                    # A uniform image has peak 0 at this point; dividing would
                    # fill it with NaNs, so it is saved as all zeros instead.
                    if peak > 0:
                        img.div_(peak)
                    k = b * data_set.batch_size + i
                    filename = mistake_filename_pattern.format(k, target[i])
                    torchvision.utils.save_image(img, filename)
                    print(Fore.RED + 'Wrote ' + filename + Style.RESET_ALL)
    return ne
363
364 ######################################################################
365
def train_model(model, model_filename, train_set, validation_set, nb_epochs_done = 0):
    """Train *model* with SGD on the cross-entropy loss, checkpointing each epoch.

    Training resumes at epoch *nb_epochs_done* and runs up to args.nb_epochs.
    After every epoch the summed training loss is logged (with an ETA on the
    console) and ``[state_dict, epochs_completed]`` is saved to
    *model_filename*. When *validation_set* is not None, the validation error
    is logged too, and training stops early once the error rate is at or
    below args.validation_error_threshold.

    Args:
        model: the network to train; it is updated in place.
        model_filename: path of the checkpoint rewritten after every epoch.
        train_set: data set exposing ``nb_batches`` and ``get_batch(b)``.
        validation_set: data set used for early stopping, or None.
        nb_epochs_done: number of epochs already completed (when resuming).

    Returns:
        *model*, the same object that was passed in.
    """
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        criterion.cuda()

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    start_t = time.time()

    for e in range(nb_epochs_done, args.nb_epochs):
        acc_loss = 0.0
        for b in range(0, train_set.nb_batches):
            batch_input, target = train_set.get_batch(b)
            output = model.forward(Variable(batch_input))
            loss = criterion(output, Variable(target))
            # Current PyTorch only supports .item() on a 0-dim loss;
            # loss.data[0] is kept as the fallback for older releases.
            acc_loss = acc_loss + (loss.item() if hasattr(loss, 'item') else loss.data[0])
            model.zero_grad()
            loss.backward()
            optimizer.step()

        # ETA: mean duration of the epochs run in *this* session (epochs
        # restored from a checkpoint took no time here) times the number of
        # epochs still left after the one just finished.
        nb_epochs_run = e + 1 - nb_epochs_done
        dt = (time.time() - start_t) / nb_epochs_run
        nb_epochs_left = args.nb_epochs - (e + 1)

        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                   ' [ETA ' + time.ctime(time.time() + dt * nb_epochs_left) + ']')

        torch.save([ model.state_dict(), e + 1 ], model_filename)

        if validation_set is not None:
            nb_validation_errors = nb_errors(model, validation_set)
            # float() keeps this a true division under Python 2 as well.
            validation_error_rate = float(nb_validation_errors) / validation_set.nb_samples

            log_string('validation_error {:.02f}% {:d} {:d}'.format(
                100 * validation_error_rate,
                nb_validation_errors,
                validation_set.nb_samples)
            )

            if validation_error_rate <= args.validation_error_threshold:
                log_string('below validation_error_threshold')
                break

    return model
408
409 ######################################################################
410
# Record every command-line argument (name and parsed value) in the log.
for arg_name, arg_value in vars(args).items():
    log_string('argument ' + str(arg_name) + ' ' + str(arg_value))
413
414 ######################################################################
415
def int_to_suffix(n):
    """Return *n* as a compact string, used in checkpoint file names.

    Exact multiples of one million are written with an 'M' suffix (e.g.
    '2M'), other exact multiples of one thousand with a 'K' suffix (e.g.
    '100K'); any other value is written out in full.
    """
    for scale, letter in ((1000000, 'M'), (1000, 'K')):
        if n >= scale and n % scale == 0:
            return str(n // scale) + letter
    return str(n)
423
class vignette_logger():
    """Callable progress reporter for vignette generation.

    Instances are passed as ``logger`` to the VignetteSet constructors. A call
    ``logger(n, m)`` treats *n* as the total number of samples and *m* as the
    number produced so far. It logs 'sample_generation m / n' with an ETA,
    but only if *delay_min* has elapsed since the previous report.
    """

    def __init__(self, delay_min = 60):
        self.start_t = time.time()
        self.last_t = self.start_t
        # Minimum delay between two reports. Despite the name it is compared
        # with time.time() differences, so it is expressed in seconds.
        self.delay_min = delay_min

    def __call__(self, n, m):
        t = time.time()
        if t > self.last_t + self.delay_min:
            if m > 0:
                # Extrapolate the mean time per sample observed so far.
                dt = (t - self.start_t) / m
                eta = ' [ETA ' + time.ctime(t + dt * (n - m)) + ']'
            else:
                # Nothing produced yet: there is no rate to extrapolate from.
                eta = ''
            log_string('sample_generation {:d} / {:d}'.format(m, n), eta)
            self.last_t = t
439
def save_examplar_vignettes(data_set, nb, name):
    """Save *nb* randomly chosen samples of *data_set* as one image grid *name*.

    Each vignette is independently rescaled to [0, 1] before being placed in
    the grid. *nb* is capped at ``data_set.nb_samples``. Nothing is written
    when there is no sample to show.

    Args:
        data_set: object exposing ``nb_samples``, ``batch_size`` and
            ``get_batch(b)`` returning an ``(input, target)`` pair.
        nb: number of vignettes to save.
        name: output image file name.
    """
    nb = min(nb, data_set.nb_samples)
    if nb <= 0:
        # An empty grid cannot be saved (and randperm().narrow would fail
        # for nb > nb_samples), so stop here.
        return

    indices = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)

    patchwork = None
    for k in range(0, nb):
        # int() turns a tensor element into a plain index on every PyTorch version.
        b, m = divmod(int(indices[k]), data_set.batch_size)
        batch_input, _ = data_set.get_batch(b)
        vignette = batch_input[m].float()
        vignette.sub_(vignette.min())
        peak = vignette.max()
        # A uniform vignette has peak 0 here; keep it at zero instead of NaN.
        if peak > 0:
            vignette.div_(peak)
        if patchwork is None:
            patchwork = Tensor(nb, 1, vignette.size(1), vignette.size(2))
        patchwork[k].copy_(vignette)

    torchvision.utils.save_image(patchwork, name)
454
455 ######################################################################
456
# Every data set built below is split into batches of args.batch_size, so
# the size of each one that is actually built must be a multiple of it. The
# validation set is only built when early stopping is enabled.
if args.batch_size < 1:
    raise ValueError('--batch_size must be positive, got {:d}'.format(args.batch_size))

_sample_counts = [ ('nb_train_samples', args.nb_train_samples),
                   ('nb_test_samples', args.nb_test_samples) ]
if args.validation_error_threshold > 0.0:
    _sample_counts.append(('nb_validation_samples', args.nb_validation_samples))

_bad_counts = [ '--' + option for option, count in _sample_counts
                if count % args.batch_size > 0 ]

if _bad_counts:
    print('The number of samples must be a multiple of the batch size.')
    # A bare `raise` outside an except clause only produces an unrelated
    # "no active exception" error; raise one that says what is wrong.
    raise ValueError('{} must be a multiple of --batch_size ({:d})'.format(
        ', '.join(_bad_counts), args.batch_size))
460
log_string('############### start ###############')

# Choose the data-set implementation once: every VignetteSet(...) built later
# uses the class bound here (compressed or plain vignette storage).
_storage_msg = 'using_compressed_vignettes' if args.compress_vignettes else 'using_uncompressed_vignettes'
log_string(_storage_msg)
VignetteSet = svrtset.CompressedVignetteSet if args.compress_vignettes else svrtset.VignetteSet
469
470 ########################################
# Resolve --model to one of the architectures defined above by matching
# their `name` class attribute.
_model_classes = [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2, DeepNet3 ]
model_class = next((c for c in _model_classes if c.name == args.model), None)

if model_class is None:
    print('Unknown model ' + args.model)
    # A bare `raise` here has no active exception to re-raise; raise one
    # that names the accepted values instead.
    raise ValueError('Unknown model {} (known models: {})'.format(
        args.model, ', '.join(c.name for c in _model_classes)))

log_string('using model class ' + model_class.name)
481 ########################################
482
# For each requested SVRT problem: build the model, resume it from its
# checkpoint if possible, train it if it has not reached args.nb_epochs, and
# evaluate it on a fresh test set.
for problem_number in map(int, args.problems.split(',')):

    log_string('############### problem ' + str(problem_number) + ' ###############')

    model = model_class()

    if torch.cuda.is_available(): model.cuda()

    # Checkpoint name, e.g. deepnet_pb:1_ns:100K.state
    model_filename = model.name + '_pb:' + \
                     str(problem_number) + '_ns:' + \
                     int_to_suffix(args.nb_train_samples) + '.state'

    nb_parameters = sum(p.numel() for p in model.parameters())
    log_string('nb_parameters {:d}'.format(nb_parameters))

    ##################################################
    # Try to resume from an existing checkpoint

    try:
        model_state_dict, nb_epochs_done = torch.load(model_filename)
        model.load_state_dict(model_state_dict)
        log_string('loaded_model ' + model_filename)
    except Exception as err:
        # Any failure to load (e.g. no checkpoint has been written yet) means
        # training starts from epoch 0. Unlike a bare `except:`, this does not
        # swallow the SystemExit raised by the SIGINT/SIGTERM handlers, nor
        # KeyboardInterrupt.
        log_string('no_model_loaded ' + model_filename, ' (' + str(err) + ')')
        nb_epochs_done = 0

    ##################################################
    # Train if necessary

    if nb_epochs_done < args.nb_epochs:

        log_string('training_model ' + model_filename)

        t = time.time()

        train_set = VignetteSet(problem_number,
                                args.nb_train_samples, args.batch_size,
                                cuda = torch.cuda.is_available(),
                                logger = vignette_logger())

        log_string('data_generation {:0.2f} samples / s'.format(
            train_set.nb_samples / (time.time() - t))
        )

        if args.nb_exemplar_vignettes > 0:
            save_examplar_vignettes(train_set, args.nb_exemplar_vignettes,
                                    'examplar_{:d}.png'.format(problem_number))

        # A validation set is only needed for early stopping.
        if args.validation_error_threshold > 0.0:
            validation_set = VignetteSet(problem_number,
                                         args.nb_validation_samples, args.batch_size,
                                         cuda = torch.cuda.is_available(),
                                         logger = vignette_logger())
        else:
            validation_set = None

        train_model(model, model_filename, train_set, validation_set, nb_epochs_done = nb_epochs_done)
        log_string('saved_model ' + model_filename)

        nb_train_errors = nb_errors(model, train_set)

        log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100.0 * nb_train_errors / train_set.nb_samples,
            nb_train_errors,
            train_set.nb_samples)
        )

    ##################################################
    # Test if necessary

    if nb_epochs_done < args.nb_epochs or args.test_loaded_models:

        test_set = VignetteSet(problem_number,
                               args.nb_test_samples, args.batch_size,
                               cuda = torch.cuda.is_available())

        # Misclassified test vignettes are written to disk only when
        # --save_test_mistakes is set.
        if args.save_test_mistakes:
            mistake_filename_pattern = 'mistake_{:06d}_{:d}.png'
        else:
            mistake_filename_pattern = None

        nb_test_errors = nb_errors(model, test_set,
                                   mistake_filename_pattern = mistake_filename_pattern)

        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100.0 * nb_test_errors / test_set.nb_samples,
            nb_test_errors,
            test_set.nb_samples)
        )
572
573 ######################################################################