Update.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27
28 import distutils.util
29 import re
30 import signal
31
32 from colorama import Fore, Back, Style
33
34 # Pytorch
35
36 import torch
37 import torchvision
38
39 from torch import optim
40 from torch import multiprocessing
41 from torch import FloatTensor as Tensor
42 from torch.autograd import Variable
43 from torch import nn
44 from torch.nn import functional as fn
45
46 from torchvision import datasets, transforms, utils
47
48 # SVRT
49
50 import svrtset
51
52 ######################################################################
53
# Command-line options. They are declared in a table and registered in
# that order; boolean options go through distutils.util.strtobool and
# therefore accept yes/no, true/false, 1/0, etc.

parser = argparse.ArgumentParser(
    description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

option_table = (
    ('--nb_train_samples',
     dict(type = int, default = 100000)),

    ('--nb_test_samples',
     dict(type = int, default = 10000)),

    ('--nb_validation_samples',
     dict(type = int, default = 10000)),

    ('--validation_error_threshold',
     dict(type = float, default = 0.0,
          help = 'Early training termination criterion')),

    ('--nb_epochs',
     dict(type = int, default = 50)),

    ('--batch_size',
     dict(type = int, default = 100)),

    ('--log_file',
     dict(type = str, default = 'default.log')),

    ('--nb_exemplar_vignettes',
     dict(type = int, default = 32)),

    ('--compress_vignettes',
     dict(type = distutils.util.strtobool, default = 'True',
          help = 'Use lossless compression to reduce the memory footprint')),

    ('--save_test_mistakes',
     dict(type = distutils.util.strtobool, default = 'False')),

    ('--model',
     dict(type = str, default = 'deepnet',
          help = 'What model to use')),

    ('--test_loaded_models',
     dict(type = distutils.util.strtobool, default = 'False',
          help = 'Should we compute the test errors of loaded models')),

    ('--problems',
     dict(type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
          help = 'What problems to process')),
)

for option_flag, option_settings in option_table:
    parser.add_argument(option_flag, **option_settings)

args = parser.parse_args()
104
105 ######################################################################
106
# Logging state: the log file is opened in append mode, pred_log_t is
# the time of the previous log_string call (None before the first one)
# and last_tag_t the time of the last stand-alone date line.

log_file = open(args.log_file, 'a')
pred_log_t = None
last_tag_t = time.time()

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)

# Writes the string s to the log file and prints it on the console,
# both with a time stamp and the time elapsed since the previous
# call. The remark is only printed on the console, never written to
# the log file.

def log_string(s, remark = ''):
    global pred_log_t, last_tag_t

    now = time.time()

    # 'start' for the very first message, then the delay since the
    # previous one
    elapsed = 'start' if pred_log_t is None else '+{:.02f}s'.format(now - pred_log_t)
    pred_log_t = now

    # Print a stand-alone date line when more than one hour has passed
    # since the last one
    if now > last_tag_t + 3600:
        last_tag_t = now
        print(Fore.RED + time.ctime() + Style.RESET_ALL)

    # In the file, the blanks of the time stamp are replaced with
    # underscores so that it makes a single space-free field
    file_stamp = time.ctime().replace(' ', '_')
    log_file.write(' '.join([file_stamp, elapsed, s]) + '\n')
    log_file.flush()

    console_prefix = Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL
    print(console_prefix + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
140
141 ######################################################################
142
# Signal handlers: log which signal was received, then terminate the
# process with a zero status. SystemExit is raised directly instead of
# calling exit(), which is a helper installed by the site module for
# interactive sessions and is not guaranteed to exist.

def handler_sigint(signum, frame):
    log_string('got sigint')
    raise SystemExit(0)

def handler_sigterm(signum, frame):
    log_string('got sigterm')
    raise SystemExit(0)

signal.signal(signal.SIGINT, handler_sigint)
signal.signal(signal.SIGTERM, handler_sigterm)
153
154 ######################################################################
155
156 # Afroze's ShallowNet
157
158 #                       map size   nb. maps
159 #                     ----------------------
160 #    input                128x128    1
161 # -- conv(21x21 x 6)   -> 108x108    6
162 # -- max(2x2)          -> 54x54      6
163 # -- conv(19x19 x 16)  -> 36x36      16
164 # -- max(2x2)          -> 18x18      16
165 # -- conv(18x18 x 120) -> 1x1        120
166 # -- reshape           -> 120        1
167 # -- full(120x84)      -> 84         1
168 # -- full(84x2)        -> 2          1
169
class AfrozeShallowNet(nn.Module):
    # Identifier matched against the --model option and used as the
    # prefix of the saved model file name
    name = 'shallownet'

    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)

    # Returns the two un-normalized class scores of each input sample

    def forward(self, x):
        # Two stages of convolution, 2x2 max-pooling and ReLU
        for conv in (self.conv1, self.conv2):
            x = conv(x)
            x = fn.max_pool2d(x, kernel_size=2)
            x = fn.relu(x)

        # Last convolution, without pooling
        x = self.conv3(x)
        x = fn.relu(x)

        # One row of 120 features per sample for the linear layers
        x = x.view(-1, self.fc1.in_features)

        hidden = fn.relu(self.fc1(x))
        return self.fc2(hidden)
189
190 ######################################################################
191
192 # Afroze's DeepNet
193
class AfrozeDeepNet(nn.Module):

    # Identifier matched against the --model option and used as the
    # prefix of the saved model file name
    name = 'deepnet'

    def __init__(self):
        super(AfrozeDeepNet, self).__init__()
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(1536, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    # Returns the two un-normalized class scores of each input sample

    def forward(self, x):
        # Each convolution is followed by a ReLU, preceded by a 2x2
        # max-pooling where the flag is set
        conv_stages = (
            (self.conv1, True),
            (self.conv2, True),
            (self.conv3, False),
            (self.conv4, False),
            (self.conv5, True),
        )

        for conv, with_pooling in conv_stages:
            x = conv(x)
            if with_pooling:
                x = fn.max_pool2d(x, kernel_size=2)
            x = fn.relu(x)

        # One row of 1536 features per sample for the linear layers
        x = x.view(-1, 1536)

        x = fn.relu(self.fc1(x))
        x = fn.relu(self.fc2(x))

        return self.fc3(x)
239
240 ######################################################################
241
class DeepNet2(nn.Module):
    # Identifier matched against the --model option and used as the
    # prefix of the saved model file name
    name = 'deepnet2'

    def __init__(self):
        super(DeepNet2, self).__init__()
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 256, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(4096, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 2)

    # Returns the two un-normalized class scores of each input sample

    def forward(self, x):
        # Convolutions 1, 2 and 5 are followed by a 2x2 max-pooling
        # and a ReLU, convolutions 3 and 4 by a ReLU only
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        x = fn.relu(self.conv4(x))
        x = fn.relu(fn.max_pool2d(self.conv5(x), kernel_size=2))

        # One row of 4096 features per sample for the linear layers
        x = x.view(-1, 4096)

        x = fn.relu(self.fc1(x))
        x = fn.relu(self.fc2(x))

        return self.fc3(x)
286
287 ######################################################################
288
class DeepNet3(nn.Module):
    # Identifier matched against the --model option and used as the
    # prefix of the saved model file name
    name = 'deepnet3'

    def __init__(self):
        super(DeepNet3, self).__init__()
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 128, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv7 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(2048, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    # Returns the two un-normalized class scores of each input sample

    def forward(self, x):
        # Convolutions listed here are followed by a 2x2 max-pooling
        # before their ReLU, the other ones by the ReLU only
        pooled = (self.conv1, self.conv2, self.conv5)

        for conv in (self.conv1, self.conv2, self.conv3, self.conv4,
                     self.conv5, self.conv6, self.conv7):
            x = conv(x)
            if conv in pooled:
                x = fn.max_pool2d(x, kernel_size=2)
            x = fn.relu(x)

        # One row of 2048 features per sample for the linear layers
        x = x.view(-1, 2048)

        for fc in (self.fc1, self.fc2):
            x = fn.relu(fc(x))

        return self.fc3(x)
341
342 ######################################################################
343
# Returns the number of samples of data_set on which the
# winner-take-all prediction of the model differs from the target.
#
# If mistake_filename_pattern is not None, every misclassified sample
# is also saved as an image. The pattern is formatted with the index of
# the sample in the data set and its target class.

def nb_errors(model, data_set, mistake_filename_pattern = None):
    ne = 0
    for b in range(0, data_set.nb_batches):
        samples, target = data_set.get_batch(b)
        output = model.forward(Variable(samples))
        # Index of the largest output along dimension 1
        wta_prediction = output.data.max(1)[1].view(-1)

        for i in range(0, data_set.batch_size):
            if wta_prediction[i] != target[i]:
                ne = ne + 1
                if mistake_filename_pattern is not None:
                    # Rescale a copy of the sample to [0, 1]
                    img = samples[i].clone()
                    img.sub_(img.min())
                    peak = img.max()
                    # A constant image has a zero peak after the
                    # shift, leave it as is instead of dividing by 0
                    if peak > 0:
                        img.div_(peak)
                    # Index in the whole set, not b + i which is the
                    # same for different (batch, position) pairs
                    sample_index = b * data_set.batch_size + i
                    filename = mistake_filename_pattern.format(sample_index, target[i])
                    torchvision.utils.save_image(img, filename)
                    print(Fore.RED + 'Wrote ' + filename + Style.RESET_ALL)
    return ne
362
363 ######################################################################
364
# Trains the model with SGD and a cross-entropy loss, from epoch
# nb_epochs_done up to args.nb_epochs.
#
# After every epoch, the accumulated training loss is logged and the
# pair [ state_dict, number of epochs done ] is saved in
# model_filename. If validation_set is not None, the validation error
# is logged too, and the training stops as soon as the error rate is
# not greater than args.validation_error_threshold.
#
# Returns the model.

def train_model(model, model_filename, train_set, validation_set, nb_epochs_done = 0):
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        criterion.cuda()

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    start_t = time.time()

    for e in range(nb_epochs_done, args.nb_epochs):
        acc_loss = 0.0
        for b in range(0, train_set.nb_batches):
            input, target = train_set.get_batch(b)
            output = model.forward(Variable(input))
            loss = criterion(output, Variable(target))
            # NOTE(review): loss.data[0] is the pre-0.4 PyTorch way of
            # reading a scalar loss, recent versions need loss.item()
            acc_loss = acc_loss + loss.data[0]
            model.zero_grad()
            loss.backward()
            optimizer.step()

        # The ETA is based on the epochs run in this call only, since
        # start_t does not account for the epochs done before resuming
        nb_epochs_this_run = e + 1 - nb_epochs_done
        nb_epochs_left = args.nb_epochs - (e + 1)
        dt = (time.time() - start_t) / nb_epochs_this_run

        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                   ' [ETA ' + time.ctime(time.time() + dt * nb_epochs_left) + ']')

        torch.save([ model.state_dict(), e + 1 ], model_filename)

        if validation_set is not None:
            nb_validation_errors = nb_errors(model, validation_set)

            # Explicit float arithmetic so that the rates are not
            # truncated by the integer division of Python 2
            log_string('validation_error {:.02f}% {:d} {:d}'.format(
                100.0 * nb_validation_errors / validation_set.nb_samples,
                nb_validation_errors,
                validation_set.nb_samples)
            )

            if float(nb_validation_errors) / validation_set.nb_samples <= args.validation_error_threshold:
                log_string('below validation_error_threshold')
                break

    return model
407
408 ######################################################################
409
# Keep a trace of all the option values in the log

for option_name, option_value in vars(args).items():
    log_string('argument ' + str(option_name) + ' ' + str(option_value))
412
413 ######################################################################
414
# Returns a compact string for n: exact multiples of one million get
# the suffix 'M', other exact multiples of one thousand the suffix
# 'K', and any other value is converted as is.

def int_to_suffix(n):
    for unit, suffix in ((1000000, 'M'), (1000, 'K')):
        if n >= unit and n % unit == 0:
            return str(n // unit) + suffix
    return str(n)
422
# Callable progress reporter for the generation of the samples. It
# logs a message with an ETA, but not more often than every delay_min
# seconds.

class vignette_logger():
    def __init__(self, delay_min = 60):
        self.start_t = time.time()
        # Time of the last message, initially the creation time
        self.last_t = self.start_t
        self.delay_min = delay_min

    # n is the total number of samples and m the number of samples
    # already generated

    def __call__(self, n, m):
        t = time.time()

        # Too early since the last message, nothing to do
        if not t > self.last_t + self.delay_min:
            return

        # Average time per sample so far
        dt = (t - self.start_t) / m
        progress = 'sample_generation {:d} / {:d}'.format(m, n)
        eta = time.ctime(time.time() + dt * (n - m))
        log_string(progress, ' [ETA ' + eta + ']')
        self.last_t = t
438
# Picks nb samples at random in data_set, without repetition, rescales
# each of them to [0, 1], and saves them all in a single image file
# called name.

def save_examplar_vignettes(data_set, nb, name):
    picked = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)

    for k in range(nb):
        # Batch holding the sample, and position of the sample in it
        batch_index = picked[k] // data_set.batch_size
        position = picked[k] % data_set.batch_size

        vignettes, _ = data_set.get_batch(batch_index)
        vignette = vignettes[position].float()
        vignette.sub_(vignette.min())
        vignette.div_(vignette.max())

        # The size of the vignettes is only known once the first one
        # has been read
        if k == 0:
            patchwork = Tensor(nb, 1, vignette.size(1), vignette.size(2))

        patchwork[k].copy_(vignette)

    torchvision.utils.save_image(patchwork, name)
453
454 ######################################################################
455
# Refuse to start if the training or test set cannot be split into
# full batches. A real exception is raised, since a bare 'raise'
# outside of an except block fails with an unrelated RuntimeError.

if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
    raise ValueError('The number of samples must be a multiple of the batch size.')

log_string('############### start ###############')

if args.compress_vignettes:
    log_string('using_compressed_vignettes')
    VignetteSet = svrtset.CompressedVignetteSet
else:
    log_string('using_uncompressed_vignettes')
    VignetteSet = svrtset.VignetteSet

########################################
# Pick the model class whose name matches the --model option

model_class = None
for m in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2, DeepNet3 ]:
    if args.model == m.name:
        model_class = m
        break
if model_class is None:
    raise ValueError('Unknown model ' + args.model)

# Use the selected class itself, not the variable left over by the loop
log_string('using model class ' + model_class.name)
480 ########################################
481
# For each requested problem: create a model, resume from its saved
# state if there is one, train it if it has not done args.nb_epochs
# epochs yet, and compute its test error if it was trained in this run
# or if --test_loaded_models is set.

for problem_number in map(int, args.problems.split(',')):

    log_string('############### problem ' + str(problem_number) + ' ###############')

    model = model_class()

    if torch.cuda.is_available(): model.cuda()

    model_filename = model.name + '_pb:' + \
                     str(problem_number) + '_ns:' + \
                     int_to_suffix(args.nb_train_samples) + '.state'

    nb_parameters = sum(p.numel() for p in model.parameters())
    log_string('nb_parameters {:d}'.format(nb_parameters))

    ##################################################
    # Tries to load the model

    # Any failure to load means starting from scratch. Only Exception
    # is caught, so that the SystemExit raised by the signal handlers
    # is not swallowed while loading.

    try:
        model_state_dict, nb_epochs_done = torch.load(model_filename)
        model.load_state_dict(model_state_dict)
        log_string('loaded_model ' + model_filename)
    except Exception:
        nb_epochs_done = 0

    ##################################################
    # Train if necessary

    if nb_epochs_done < args.nb_epochs:

        log_string('training_model ' + model_filename)

        t = time.time()

        train_set = VignetteSet(problem_number,
                                args.nb_train_samples, args.batch_size,
                                cuda = torch.cuda.is_available(),
                                logger = vignette_logger())

        log_string('data_generation {:0.2f} samples / s'.format(
            train_set.nb_samples / (time.time() - t))
        )

        if args.nb_exemplar_vignettes > 0:
            save_examplar_vignettes(train_set, args.nb_exemplar_vignettes,
                                    'examplar_{:d}.png'.format(problem_number))

        # A validation set is only needed for the early termination
        if args.validation_error_threshold > 0.0:
            validation_set = VignetteSet(problem_number,
                                         args.nb_validation_samples, args.batch_size,
                                         cuda = torch.cuda.is_available(),
                                         logger = vignette_logger())
        else:
            validation_set = None

        train_model(model, model_filename, train_set, validation_set, nb_epochs_done = nb_epochs_done)
        log_string('saved_model ' + model_filename)

        nb_train_errors = nb_errors(model, train_set)

        # 100.0 so that the percentage is not truncated by the integer
        # division of Python 2
        log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100.0 * nb_train_errors / train_set.nb_samples,
            nb_train_errors,
            train_set.nb_samples)
        )

    ##################################################
    # Test if necessary

    if nb_epochs_done < args.nb_epochs or args.test_loaded_models:

        test_set = VignetteSet(problem_number,
                               args.nb_test_samples, args.batch_size,
                               cuda = torch.cuda.is_available())

        # The images of the mistakes are written only on request
        if args.save_test_mistakes:
            mistake_filename_pattern = 'mistake_{:06d}_{:d}.png'
        else:
            mistake_filename_pattern = None

        nb_test_errors = nb_errors(model, test_set,
                                   mistake_filename_pattern = mistake_filename_pattern)

        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100.0 * nb_test_errors / test_set.nb_samples,
            nb_test_errors,
            test_set.nb_samples)
        )

######################################################################
571
572 ######################################################################