# Exported from pysvrt.git / cnn-svrt.py (git web-viewer residue; commit message: "Minor update.")
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27
28 import distutils.util
29 import re
30 import signal
31
32 from colorama import Fore, Back, Style
33
34 # Pytorch
35
36 import torch
37 import torchvision
38
39 from torch import optim
40 from torch import multiprocessing
41 from torch import FloatTensor as Tensor
42 from torch.autograd import Variable
43 from torch import nn
44 from torch.nn import functional as fn
45
46 from torchvision import datasets, transforms, utils
47
48 # SVRT
49
50 import svrtset
51
52 ######################################################################
53
# Command-line interface. ArgumentDefaultsHelpFormatter makes --help show
# every default value. All sample counts are expected to be exact multiples
# of --batch_size (checked later, before data generation).
parser = argparse.ArgumentParser(
    description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

# Data-set sizes: training, test, and the validation set used for early stopping.
parser.add_argument('--nb_train_samples',
                    type = int, default = 100000)

parser.add_argument('--nb_test_samples',
                    type = int, default = 10000)

parser.add_argument('--nb_validation_samples',
                    type = int, default = 10000)

# Early stopping triggers when the validation error rate reaches this value;
# 0.0 disables validation entirely (no validation set is generated).
parser.add_argument('--validation_error_threshold',
                    type = float, default = 0.0,
                    help = 'Early training termination criterion')

parser.add_argument('--nb_epochs',
                    type = int, default = 50)

parser.add_argument('--batch_size',
                    type = int, default = 100)

parser.add_argument('--log_file',
                    type = str, default = 'default.log')

# Number of randomly chosen training vignettes saved as one PNG per problem
# (0 disables the dump).
parser.add_argument('--nb_exemplar_vignettes',
                    type = int, default = 32)

# NOTE(review): distutils.util.strtobool is deprecated and removed with
# distutils in Python 3.12 — consider a local str-to-bool helper.
parser.add_argument('--compress_vignettes',
                    type = distutils.util.strtobool, default = 'True',
                    help = 'Use lossless compression to reduce the memory footprint')

parser.add_argument('--save_test_mistakes',
                    type = distutils.util.strtobool, default = 'False')

# Must match the `name` class attribute of one of the model classes below.
parser.add_argument('--model',
                    type = str, default = 'deepnet',
                    help = 'What model to use')

parser.add_argument('--test_loaded_models',
                    type = distutils.util.strtobool, default = 'False',
                    help = 'Should we compute the test errors of loaded models')

# Comma-separated list of SVRT problem numbers (1..23) to process.
parser.add_argument('--problems',
                    type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
                    help = 'What problems to process')

args = parser.parse_args()
104
105 ######################################################################
106
# Open the log in append mode and write a dated session separator so
# successive runs in the same file are easy to tell apart.
log_file = open(args.log_file, 'a')
log_file.write('\n')
log_file.write('@@@@@@@@@@@@@@@@@@@ ' + time.ctime() + ' @@@@@@@@@@@@@@@@@@@\n')
log_file.write('\n')

# pred_log_t: time of the previous log_string call (None before the first),
# used to print the elapsed time between log lines.
# last_tag_t: time of the last full date tag printed to the console
# (log_string re-prints it at most once per hour).
pred_log_t = None
last_tag_t = time.time()

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
116
117 # Log and prints the string, with a time stamp. Does not log the
118 # remark
119
def log_string(s, remark = ''):
    """Write s to the log file and print it, both prefixed with a time stamp
    and the delay since the previous call. The remark is only printed, never
    logged."""
    global pred_log_t, last_tag_t

    now = time.time()

    # Elapsed time since the previous log line ('start' on the first call).
    elapsed = 'start' if pred_log_t is None else '+{:.02f}s'.format(now - pred_log_t)
    pred_log_t = now

    # Re-print a full red date tag on the console at most once per hour.
    if now > last_tag_t + 3600:
        last_tag_t = now
        print(Fore.RED + time.ctime() + Style.RESET_ALL)

    log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n')
    log_file.flush()

    print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed
          + Style.RESET_ALL + ' '
          + s + Fore.CYAN + remark
          + Style.RESET_ALL)
144
145 ######################################################################
146
# Terminate cleanly on SIGINT/SIGTERM, recording in the log why the run
# ended. exit(0) raises SystemExit, which unwinds normally and flushes the
# log file.

def handler_sigint(signum, frame):
    log_string('got sigint')
    exit(0)

def handler_sigterm(signum, frame):
    log_string('got sigterm')
    exit(0)

signal.signal(signal.SIGINT, handler_sigint)
signal.signal(signal.SIGTERM, handler_sigterm)
157
158 ######################################################################
159
160 # Afroze's ShallowNet
161
162 #                       map size   nb. maps
163 #                     ----------------------
164 #    input                128x128    1
165 # -- conv(21x21 x 6)   -> 108x108    6
166 # -- max(2x2)          -> 54x54      6
167 # -- conv(19x19 x 16)  -> 36x36      16
168 # -- max(2x2)          -> 18x18      16
169 # -- conv(18x18 x 120) -> 1x1        120
170 # -- reshape           -> 120        1
171 # -- full(120x84)      -> 84         1
172 # -- full(84x2)        -> 2          1
173
class AfrozeShallowNet(nn.Module):
    """Afroze's shallow network for 128x128 single-channel SVRT vignettes.

    Three conv layers (the first two followed by 2x2 max-pooling) reduce the
    input to a 120-dim vector, then two fully connected layers produce the
    2 class logits. See the layer-size table in the comment above.
    """

    name = 'shallownet'

    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)    # 128x128 -> 108x108
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)   # 54x54 -> 36x36
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18) # 18x18 -> 1x1
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)

    def forward(self, x):
        h = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        h = fn.relu(fn.max_pool2d(self.conv2(h), kernel_size=2))
        h = fn.relu(self.conv3(h))
        h = fn.relu(self.fc1(h.view(-1, 120)))
        return self.fc2(h)
193
194 ######################################################################
195
196 # Afroze's DeepNet
197
class AfrozeDeepNet(nn.Module):
    """Afroze's deep network: five conv layers (three followed by 2x2
    max-pooling) and three fully connected layers, producing 2 class logits
    from a 128x128 single-channel vignette."""

    name = 'deepnet'

    def __init__(self):
        super(AfrozeDeepNet, self).__init__()
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(1536, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    def forward(self, x):
        # 128x128 -> 32x32 (stride-4 conv) -> 16x16 -> 8x8 -> 4x4 after the
        # three poolings, so the flattened feature is 96 * 4 * 4 = 1536.
        h = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        h = fn.relu(fn.max_pool2d(self.conv2(h), kernel_size=2))
        h = fn.relu(self.conv3(h))
        h = fn.relu(self.conv4(h))
        h = fn.relu(fn.max_pool2d(self.conv5(h), kernel_size=2))

        h = h.view(-1, 1536)

        h = fn.relu(self.fc1(h))
        h = fn.relu(self.fc2(h))
        return self.fc3(h)
243
244 ######################################################################
245
class DeepNet2(nn.Module):
    """Wider variant of AfrozeDeepNet: 512 feature channels and 512-unit
    fully connected layers, same five-conv / three-fc topology."""

    name = 'deepnet2'

    def __init__(self):
        super(DeepNet2, self).__init__()
        self.nb_channels = 512
        c = self.nb_channels
        self.conv1 = nn.Conv2d(1, 32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d(32, c, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(c, c, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(c, c, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(c, c, kernel_size=3, padding=1)
        # Final maps are 4x4, hence 16 * nb_channels flattened features.
        self.fc1 = nn.Linear(16 * c, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 2)

    def forward(self, x):
        h = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        h = fn.relu(fn.max_pool2d(self.conv2(h), kernel_size=2))
        h = fn.relu(self.conv3(h))
        h = fn.relu(self.conv4(h))
        h = fn.relu(fn.max_pool2d(self.conv5(h), kernel_size=2))

        h = h.view(-1, 16 * self.nb_channels)

        h = fn.relu(self.fc1(h))
        h = fn.relu(self.fc2(h))
        return self.fc3(h)
291
292 ######################################################################
293
class DeepNet3(nn.Module):
    """Deeper variant of AfrozeDeepNet: seven conv layers at 128 channels
    and three fully connected layers."""

    name = 'deepnet3'

    def __init__(self):
        super(DeepNet3, self).__init__()
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 128, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv7 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        # Final maps are 4x4 at 128 channels: 2048 flattened features.
        self.fc1 = nn.Linear(2048, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    def forward(self, x):
        h = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))  # -> 16x16
        h = fn.relu(fn.max_pool2d(self.conv2(h), kernel_size=2))  # -> 8x8
        h = fn.relu(self.conv3(h))
        h = fn.relu(self.conv4(h))
        h = fn.relu(fn.max_pool2d(self.conv5(h), kernel_size=2))  # -> 4x4
        h = fn.relu(self.conv6(h))
        h = fn.relu(self.conv7(h))

        h = h.view(-1, 2048)

        h = fn.relu(self.fc1(h))
        h = fn.relu(self.fc2(h))
        return self.fc3(h)
346
347 ######################################################################
348
def nb_errors(model, data_set, mistake_filename_pattern = None):
    """Count the samples of data_set misclassified by model.

    Classification is winner-take-all over the model's output. If
    mistake_filename_pattern is given (formatted with the sample index and
    its target class), each misclassified vignette is also written to disk
    as a min/max-normalized image.
    """
    nb = 0
    for b in range(data_set.nb_batches):
        input, target = data_set.get_batch(b)
        output = model.forward(Variable(input))
        predictions = output.data.max(1)[1].view(-1)

        for i in range(data_set.batch_size):
            if predictions[i] == target[i]:
                continue
            nb += 1
            if mistake_filename_pattern is not None:
                # Normalize to [0, 1] before saving.
                img = input[i].clone()
                img.sub_(img.min())
                img.div_(img.max())
                k = b * data_set.batch_size + i
                filename = mistake_filename_pattern.format(k, target[i])
                torchvision.utils.save_image(img, filename)
                print(Fore.RED + 'Wrote ' + filename + Style.RESET_ALL)
    return nb
368
369 ######################################################################
370
def train_model(model, model_filename, train_set, validation_set, nb_epochs_done = 0):
    """Train model on train_set with SGD and cross-entropy, checkpointing
    after every epoch.

    Resumes at epoch nb_epochs_done and runs until args.nb_epochs. After
    each epoch, [state_dict, nb_epochs_done] is saved to model_filename and
    the accumulated loss is logged with an ETA. If validation_set is not
    None, training stops early once the validation error rate drops to
    args.validation_error_threshold or below.

    Returns the (trained) model.
    """
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        criterion.cuda()

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    start_t = time.time()

    for e in range(nb_epochs_done, args.nb_epochs):
        acc_loss = 0.0
        for b in range(0, train_set.nb_batches):
            input, target = train_set.get_batch(b)
            output = model.forward(Variable(input))
            loss = criterion(output, Variable(target))
            # loss is a 0-dim tensor: indexing it (loss.data[0]) fails on
            # PyTorch >= 0.5; item() extracts the Python float portably.
            acc_loss = acc_loss + loss.item()
            model.zero_grad()
            loss.backward()
            optimizer.step()

        # Average time per epoch so far, used to estimate the finish time.
        dt = (time.time() - start_t) / (e + 1)

        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                   ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')

        torch.save([ model.state_dict(), e + 1 ], model_filename)

        if validation_set is not None:
            nb_validation_errors = nb_errors(model, validation_set)

            log_string('validation_error {:.02f}% {:d} {:d}'.format(
                100 * nb_validation_errors / validation_set.nb_samples,
                nb_validation_errors,
                validation_set.nb_samples)
            )

            if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
                log_string('below validation_error_threshold')
                break

    return model
413
414 ######################################################################
415
# Record every command-line argument value in the log, for reproducibility.
for arg in vars(args):
    log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
418
419 ######################################################################
420
def int_to_suffix(n):
    """Render n compactly: exact multiples of 1e6 as '<x>M', exact multiples
    of 1e3 as '<x>K', anything else as plain str(n)."""
    for divisor, suffix in ((1000000, 'M'), (1000, 'K')):
        if n >= divisor and n % divisor == 0:
            return str(n // divisor) + suffix
    return str(n)
428
class vignette_logger():
    """Rate-limited progress logger for vignette generation.

    Called with (n, m) = (total number of samples, samples generated so
    far); logs the progress and an ETA at most once every delay_min seconds.
    """

    def __init__(self, delay_min = 60):
        self.start_t = time.time()
        self.last_t = self.start_t
        self.delay_min = delay_min

    def __call__(self, n, m):
        now = time.time()
        if now <= self.last_t + self.delay_min:
            return
        # Average generation time per sample so far drives the ETA.
        per_sample = (now - self.start_t) / m
        log_string(
            'sample_generation {:d} / {:d}'.format(m, n),
            ' [ETA ' + time.ctime(time.time() + per_sample * (n - m)) + ']'
        )
        self.last_t = now
444
def save_exemplar_vignettes(data_set, nb, name):
    """Save nb randomly chosen vignettes from data_set as one image grid in
    file `name`, each vignette min/max-normalized to [0, 1]."""
    picks = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)

    patchwork = None
    for j in range(nb):
        # Locate sample picks[j] as (batch index, offset within batch).
        b = picks[j] // data_set.batch_size
        m = picks[j] % data_set.batch_size
        batch_input, _ = data_set.get_batch(b)
        img = batch_input[m].float()
        img.sub_(img.min())
        img.div_(img.max())
        if patchwork is None:
            patchwork = Tensor(nb, 1, img.size(1), img.size(2))
        patchwork[j].copy_(img)

    torchvision.utils.save_image(patchwork, name)
459
460 ######################################################################
461
# The data sets are processed batch-by-batch with a fixed batch size, so the
# sample counts must tile exactly into full batches.
if args.nb_train_samples % args.batch_size > 0 or args.nb_test_samples % args.batch_size > 0:
    # A bare `raise` with no active exception aborts with a confusing
    # "No active exception to re-raise"; raise a proper exception that
    # carries the message instead.
    raise ValueError('The number of samples must be a multiple of the batch size.')
465
# Choose the vignette-set implementation once, up front: the compressed
# variant trades CPU time for a much smaller memory footprint.
if args.compress_vignettes:
    log_string('using_compressed_vignettes')
    VignetteSet = svrtset.CompressedVignetteSet
else:
    log_string('using_uncompressed_vignettes')
    VignetteSet = svrtset.VignetteSet
472
473 ########################################
# Resolve args.model to one of the model classes declared above, matching
# on each class's `name` attribute.
model_class = None
for candidate in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2, DeepNet3 ]:
    if args.model == candidate.name:
        model_class = candidate
        break
if model_class is None:
    # A bare `raise` here would fail with "No active exception to re-raise";
    # raise an informative exception instead.
    raise ValueError('Unknown model ' + args.model)

# Log via model_class rather than the loop variable left over from the search.
log_string('using model class ' + model_class.name)
484 ########################################
485
# Main loop: for every requested SVRT problem, build (or resume) a model,
# train it if its checkpoint has fewer epochs than requested, and measure
# its train/test error rates.
for problem_number in map(int, args.problems.split(',')):

    log_string('############### problem ' + str(problem_number) + ' ###############')

    model = model_class()

    if torch.cuda.is_available(): model.cuda()

    # Checkpoint filename encodes the model, the problem, and the training
    # set size, e.g. deepnet_pb:4_ns:100K.pth
    model_filename = model.name + '_pb:' + \
                     str(problem_number) + '_ns:' + \
                     int_to_suffix(args.nb_train_samples) + '.pth'

    nb_parameters = 0
    for p in model.parameters(): nb_parameters += p.numel()
    log_string('nb_parameters {:d}'.format(nb_parameters))

    ##################################################
    # Tries to load the model

    try:
        model_state_dict, nb_epochs_done = torch.load(model_filename)
        model.load_state_dict(model_state_dict)
        log_string('loaded_model ' + model_filename)
    except Exception:
        # No usable checkpoint: train from scratch. Catching Exception
        # (instead of the previous bare `except:`) keeps the SystemExit
        # raised by the SIGINT/SIGTERM handlers, and KeyboardInterrupt,
        # fatal rather than silently swallowed.
        nb_epochs_done = 0

    ##################################################
    # Train if necessary

    if nb_epochs_done < args.nb_epochs:

        log_string('training_model ' + model_filename)

        t = time.time()

        train_set = VignetteSet(problem_number,
                                args.nb_train_samples, args.batch_size,
                                cuda = torch.cuda.is_available(),
                                logger = vignette_logger())

        log_string('data_generation {:0.2f} samples / s'.format(
            train_set.nb_samples / (time.time() - t))
        )

        if args.nb_exemplar_vignettes > 0:
            save_exemplar_vignettes(train_set, args.nb_exemplar_vignettes,
                                    'exemplar_{:d}.png'.format(problem_number))

        # Early stopping needs a validation set; only generate one when the
        # threshold is actually enabled.
        if args.validation_error_threshold > 0.0:
            validation_set = VignetteSet(problem_number,
                                         args.nb_validation_samples, args.batch_size,
                                         cuda = torch.cuda.is_available(),
                                         logger = vignette_logger())
        else:
            validation_set = None

        train_model(model, model_filename,
                    train_set, validation_set,
                    nb_epochs_done = nb_epochs_done)

        log_string('saved_model ' + model_filename)

        nb_train_errors = nb_errors(model, train_set)

        log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_train_errors / train_set.nb_samples,
            nb_train_errors,
            train_set.nb_samples)
        )

    ##################################################
    # Test if necessary

    if nb_epochs_done < args.nb_epochs or args.test_loaded_models:

        t = time.time()

        test_set = VignetteSet(problem_number,
                               args.nb_test_samples, args.batch_size,
                               cuda = torch.cuda.is_available())

        # Dump misclassified vignettes only when requested: the
        # --save_test_mistakes flag was previously parsed but never used,
        # and mistake images were written unconditionally.
        if args.save_test_mistakes:
            mistake_pattern = 'mistake_{:06d}_{:d}.png'
        else:
            mistake_pattern = None

        nb_test_errors = nb_errors(model, test_set,
                                   mistake_filename_pattern = mistake_pattern)

        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_test_errors / test_set.nb_samples,
            nb_test_errors,
            test_set.nb_samples)
        )
578
579 ######################################################################