# [pysvrt.git] / cnn-svrt.py — commit: "Added a log starting line."
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27
28 import distutils.util
29 import re
30 import signal
31
32 from colorama import Fore, Back, Style
33
34 # Pytorch
35
36 import torch
37 import torchvision
38
39 from torch import optim
40 from torch import multiprocessing
41 from torch import FloatTensor as Tensor
42 from torch.autograd import Variable
43 from torch import nn
44 from torch.nn import functional as fn
45
46 from torchvision import datasets, transforms, utils
47
48 # SVRT
49
50 import svrtset
51
52 ######################################################################
53
# Command-line interface; ArgumentDefaultsHelpFormatter makes --help
# display the default value of every option.
parser = argparse.ArgumentParser(
    description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

# Data-set sizes. Train and test sizes must be multiples of --batch_size
# (checked before the main loop starts).
parser.add_argument('--nb_train_samples',
                    type = int, default = 100000)

parser.add_argument('--nb_test_samples',
                    type = int, default = 10000)

parser.add_argument('--nb_validation_samples',
                    type = int, default = 10000)

# A validation set is generated (and early stopping enabled) only when
# this threshold is strictly positive.
parser.add_argument('--validation_error_threshold',
                    type = float, default = 0.0,
                    help = 'Early training termination criterion')

parser.add_argument('--nb_epochs',
                    type = int, default = 50)

parser.add_argument('--batch_size',
                    type = int, default = 100)

parser.add_argument('--log_file',
                    type = str, default = 'default.log')

# Number of sample vignettes saved as examplar_<problem>.png (0 disables)
parser.add_argument('--nb_exemplar_vignettes',
                    type = int, default = 32)

parser.add_argument('--compress_vignettes',
                    type = distutils.util.strtobool, default = 'True',
                    help = 'Use lossless compression to reduce the memory footprint')

# When true, each misclassified test vignette is saved as an image
parser.add_argument('--save_test_mistakes',
                    type = distutils.util.strtobool, default = 'False')

# Must match the 'name' class attribute of one of the model classes below
parser.add_argument('--model',
                    type = str, default = 'deepnet',
                    help = 'What model to use')

parser.add_argument('--test_loaded_models',
                    type = distutils.util.strtobool, default = 'False',
                    help = 'Should we compute the test errors of loaded models')

# Comma-separated list of SVRT problem numbers to process
parser.add_argument('--problems',
                    type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
                    help = 'What problems to process')

args = parser.parse_args()
104
105 ######################################################################
106
# Open the log file in append mode and write a dated session separator,
# so successive runs accumulate in the same file.
log_file = open(args.log_file, 'a')
log_file.write('\n')
log_file.write('@@@@@@@@@@@@@@@@@@@ ' + time.ctime() + ' @@@@@@@@@@@@@@@@@@@\n')
log_file.write('\n')

# Timestamp of the previous log_string call (None before the first call)
pred_log_t = None
# Last time the full wall-clock date was printed to the terminal
last_tag_t = time.time()

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
116
117 # Log and prints the string, with a time stamp. Does not log the
118 # remark
119
def log_string(s, remark = ''):
    """Print s to the terminal and append it to the log file.

    Both outputs carry a timestamp and the time elapsed since the
    previous call. The remark is shown on the terminal only, never
    written to the log file.
    """
    global pred_log_t, last_tag_t

    now = time.time()

    elapsed = 'start' if pred_log_t is None else '+{:.02f}s'.format(now - pred_log_t)
    pred_log_t = now

    # Once per hour, remind the terminal reader of the full date
    if now > last_tag_t + 3600:
        last_tag_t = now
        print(Fore.RED + time.ctime() + Style.RESET_ALL)

    log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n')
    log_file.flush()

    print(Fore.BLUE + time.ctime() + ' '
          + Fore.GREEN + elapsed
          + Style.RESET_ALL
          + ' '
          + s + Fore.CYAN + remark
          + Style.RESET_ALL)
144
145 ######################################################################
146
# Log the reception of SIGINT / SIGTERM before terminating, so that
# interrupted runs leave a trace in the log file.

def handler_sigint(signum, frame):
    log_string('got sigint')
    exit(0)

def handler_sigterm(signum, frame):
    log_string('got sigterm')
    exit(0)

signal.signal(signal.SIGINT, handler_sigint)
signal.signal(signal.SIGTERM, handler_sigterm)
157
158 ######################################################################
159
160 # Afroze's ShallowNet
161
162 #                       map size   nb. maps
163 #                     ----------------------
164 #    input                128x128    1
165 # -- conv(21x21 x 6)   -> 108x108    6
166 # -- max(2x2)          -> 54x54      6
167 # -- conv(19x19 x 16)  -> 36x36      16
168 # -- max(2x2)          -> 18x18      16
169 # -- conv(18x18 x 120) -> 1x1        120
170 # -- reshape           -> 120        1
171 # -- full(120x84)      -> 84         1
172 # -- full(84x2)        -> 2          1
173
class AfrozeShallowNet(nn.Module):
    """Afroze's ShallowNet: three convolutional stages followed by two
    fully connected layers, mapping a 1x128x128 vignette to two class
    logits (see the architecture table above)."""

    name = 'shallownet'

    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        # 1x128x128 -> 6x108x108, pooled to 6x54x54
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        # 6x54x54 -> 16x36x36, pooled to 16x18x18
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        # 16x18x18 -> 120x1x1
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)

    def forward(self, x):
        h = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        h = fn.relu(fn.max_pool2d(self.conv2(h), kernel_size=2))
        h = fn.relu(self.conv3(h)).view(-1, 120)
        return self.fc2(fn.relu(self.fc1(h)))
193
194 ######################################################################
195
196 # Afroze's DeepNet
197
class AfrozeDeepNet(nn.Module):
    """Afroze's DeepNet: five convolutional layers (three followed by
    2x2 max-pooling) then three fully connected layers, mapping a
    1x128x128 vignette to two class logits."""

    name = 'deepnet'

    def __init__(self):
        super(AfrozeDeepNet, self).__init__()
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(1536, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    def forward(self, x):
        # 1x128x128 -> 32x32x32 (stride-4 conv), pooled to 32x16x16
        # (relu and max-pool commute, so relu(pool(.)) == pool(relu(.)))
        h = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        # -> 96x8x8
        h = fn.relu(fn.max_pool2d(self.conv2(h), kernel_size=2))
        h = fn.relu(self.conv3(h))
        h = fn.relu(self.conv4(h))
        # -> 96x4x4 = 1536 features
        h = fn.relu(fn.max_pool2d(self.conv5(h), kernel_size=2))

        h = h.view(-1, 1536)
        h = fn.relu(self.fc1(h))
        h = fn.relu(self.fc2(h))
        return self.fc3(h)
243
244 ######################################################################
245
class DeepNet2(nn.Module):
    """Wider variant of AfrozeDeepNet: 256-channel convolutions and
    512-unit fully connected layers, same 1x128x128 -> 2 logits
    interface."""

    name = 'deepnet2'

    def __init__(self):
        super(DeepNet2, self).__init__()
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 256, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(4096, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 2)

    def forward(self, x):
        # 1x128x128 -> 32x16x16 (stride-4 conv then 2x2 pooling);
        # relu and max-pool commute, so relu(pool(.)) == pool(relu(.))
        h = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        # -> 256x8x8
        h = fn.relu(fn.max_pool2d(self.conv2(h), kernel_size=2))
        h = fn.relu(self.conv3(h))
        h = fn.relu(self.conv4(h))
        # -> 256x4x4 = 4096 features
        h = fn.relu(fn.max_pool2d(self.conv5(h), kernel_size=2))

        h = h.view(-1, 4096)
        h = fn.relu(self.fc1(h))
        h = fn.relu(self.fc2(h))
        return self.fc3(h)
290
291 ######################################################################
292
class DeepNet3(nn.Module):
    """Deeper variant of AfrozeDeepNet: seven 128-channel convolutional
    layers, same 1x128x128 -> 2 logits interface."""

    name = 'deepnet3'

    def __init__(self):
        super(DeepNet3, self).__init__()
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 128, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv7 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(2048, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    def forward(self, x):
        # 1x128x128 -> 32x16x16 (stride-4 conv then 2x2 pooling);
        # relu and max-pool commute, so relu(pool(.)) == pool(relu(.))
        h = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        # -> 128x8x8
        h = fn.relu(fn.max_pool2d(self.conv2(h), kernel_size=2))
        h = fn.relu(self.conv3(h))
        h = fn.relu(self.conv4(h))
        # -> 128x4x4 = 2048 features
        h = fn.relu(fn.max_pool2d(self.conv5(h), kernel_size=2))
        h = fn.relu(self.conv6(h))
        h = fn.relu(self.conv7(h))

        h = h.view(-1, 2048)
        h = fn.relu(self.fc1(h))
        h = fn.relu(self.fc2(h))
        return self.fc3(h)
345
346 ######################################################################
347
def nb_errors(model, data_set, mistake_filename_pattern = None):
    """Return the number of classification errors of model on data_set.

    Prediction is the argmax over the two output logits. When
    mistake_filename_pattern is given, every misclassified vignette is
    normalized to [0, 1] and saved to
    mistake_filename_pattern.format(sample_index, target_class).
    """
    nb = 0
    for b in range(data_set.nb_batches):
        # 'batch_input' rather than 'input', to avoid shadowing the builtin
        batch_input, batch_target = data_set.get_batch(b)
        # Call the model, not model.forward, so nn.Module hooks run
        output = model(Variable(batch_input))
        wta_prediction = output.data.max(1)[1].view(-1)

        for i in range(data_set.batch_size):
            if wta_prediction[i] != batch_target[i]:
                nb += 1
                if mistake_filename_pattern is not None:
                    # Normalize to [0, 1] before saving as an image
                    img = batch_input[i].clone()
                    img.sub_(img.min())
                    img.div_(img.max())
                    # Global sample index across batches
                    k = b * data_set.batch_size + i
                    filename = mistake_filename_pattern.format(k, batch_target[i])
                    torchvision.utils.save_image(img, filename)
                    print(Fore.RED + 'Wrote ' + filename + Style.RESET_ALL)
    return nb
367
368 ######################################################################
369
def train_model(model, model_filename, train_set, validation_set, nb_epochs_done = 0):
    """Train model on train_set with SGD and cross-entropy loss.

    The weights (with the number of epochs done) are checkpointed to
    model_filename after every epoch. When validation_set is not None,
    training stops early once the validation error drops to
    args.validation_error_threshold or below. Returns the model.
    """
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        criterion.cuda()

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    start_t = time.time()

    for epoch in range(nb_epochs_done, args.nb_epochs):
        acc_loss = 0.0

        for b in range(train_set.nb_batches):
            batch_input, batch_target = train_set.get_batch(b)
            output = model.forward(Variable(batch_input))
            loss = criterion(output, Variable(batch_target))
            acc_loss += loss.data[0]
            model.zero_grad()
            loss.backward()
            optimizer.step()

        # Mean wall-clock time per epoch so far, used for the ETA estimate
        dt = (time.time() - start_t) / (epoch + 1)

        log_string('train_loss {:d} {:f}'.format(epoch + 1, acc_loss),
                   ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - epoch)) + ']')

        # Checkpoint the weights together with the number of epochs done
        torch.save([ model.state_dict(), epoch + 1 ], model_filename)

        if validation_set is not None:
            nb_validation_errors = nb_errors(model, validation_set)

            log_string('validation_error {:.02f}% {:d} {:d}'.format(
                100 * nb_validation_errors / validation_set.nb_samples,
                nb_validation_errors,
                validation_set.nb_samples)
            )

            # Early termination criterion
            if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
                log_string('below validation_error_threshold')
                break

    return model
412
413 ######################################################################
414
# Record every command-line argument value in the log, so a run can be
# reproduced from its log file alone.
for name, value in vars(args).items():
    log_string('argument ' + str(name) + ' ' + str(value))
417
418 ######################################################################
419
def int_to_suffix(n):
    """Return a compact string for n: '2M' or '150K' when n is an exact
    multiple of 1e6 or 1e3 respectively, otherwise the plain digits."""
    for divisor, suffix in ((1000000, 'M'), (1000, 'K')):
        if n >= divisor and n % divisor == 0:
            return str(n // divisor) + suffix
    return str(n)
427
class vignette_logger():
    """Callable progress logger for vignette generation.

    Called with (n, m) = (total, generated so far); emits at most one
    'sample_generation' log line (with an ETA) per delay_min seconds.
    """

    def __init__(self, delay_min = 60):
        self.start_t = time.time()
        self.last_t = self.start_t
        self.delay_min = delay_min

    def __call__(self, n, m):
        now = time.time()
        if now <= self.last_t + self.delay_min:
            return
        # Average generation time per sample so far
        dt = (now - self.start_t) / m
        log_string('sample_generation {:d} / {:d}'.format(
            m,
            n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
        )
        self.last_t = now
443
def save_examplar_vignettes(data_set, nb, name):
    """Save a patchwork of nb vignettes, picked at random without
    replacement from data_set, into the image file name."""

    indices = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)

    for k in range(nb):
        # Locate sample indices[k] as (batch, offset within batch)
        b, m = indices[k] // data_set.batch_size, indices[k] % data_set.batch_size
        image, _ = data_set.get_batch(b)
        image = image[m].float()
        # Normalize to [0, 1] for display
        image.sub_(image.min())
        image.div_(image.max())
        if k == 0:
            # All vignettes share a size; allocate once we have seen one
            patchwork = Tensor(nb, 1, image.size(1), image.size(2))
        patchwork[k].copy_(image)

    torchvision.utils.save_image(patchwork, name)
458
459 ######################################################################
460
# The training/test loops assume full batches only, so the sample
# counts must be exact multiples of the batch size.
if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
    print('The number of samples must be a multiple of the batch size.')
    # A bare 'raise' outside an except block would itself fail with
    # "RuntimeError: No active exception to re-raise"; raise explicitly.
    raise ValueError('The number of samples must be a multiple of the batch size.')

if args.compress_vignettes:
    log_string('using_compressed_vignettes')
    VignetteSet = svrtset.CompressedVignetteSet
else:
    log_string('using_uncompressed_vignettes')
    VignetteSet = svrtset.VignetteSet

########################################

# Select the model class whose 'name' attribute matches --model
model_class = None
for m in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2, DeepNet3 ]:
    if args.model == m.name:
        model_class = m
        break
if model_class is None:
    print('Unknown model ' + args.model)
    raise ValueError('Unknown model ' + args.model)

# Log via model_class, not the leaked loop variable m
log_string('using model class ' + model_class.name)
483 ########################################
484
# Main loop: for each requested SVRT problem, build (or reload) the
# model, train it if necessary, and measure its test error.
for problem_number in map(int, args.problems.split(',')):

    log_string('############### problem ' + str(problem_number) + ' ###############')

    model = model_class()

    if torch.cuda.is_available(): model.cuda()

    # The checkpoint name encodes model, problem and training-set size,
    # e.g. 'deepnet_pb:4_ns:100K.state'
    model_filename = model.name + '_pb:' + \
                     str(problem_number) + '_ns:' + \
                     int_to_suffix(args.nb_train_samples) + '.state'

    nb_parameters = 0
    for p in model.parameters(): nb_parameters += p.numel()
    log_string('nb_parameters {:d}'.format(nb_parameters))

    ##################################################
    # Tries to load the model

    try:
        model_state_dict, nb_epochs_done = torch.load(model_filename)
        model.load_state_dict(model_state_dict)
        log_string('loaded_model ' + model_filename)
    except:
        # NOTE(review): bare except restarts training from scratch on ANY
        # load failure (missing file, corrupted checkpoint, ...)
        nb_epochs_done = 0

    ##################################################
    # Train if necessary

    if nb_epochs_done < args.nb_epochs:

        log_string('training_model ' + model_filename)

        t = time.time()

        train_set = VignetteSet(problem_number,
                                args.nb_train_samples, args.batch_size,
                                cuda = torch.cuda.is_available(),
                                logger = vignette_logger())

        log_string('data_generation {:0.2f} samples / s'.format(
            train_set.nb_samples / (time.time() - t))
        )

        if args.nb_exemplar_vignettes > 0:
            save_examplar_vignettes(train_set, args.nb_exemplar_vignettes,
                                    'examplar_{:d}.png'.format(problem_number))

        # Generate a validation set (and enable early stopping) only
        # when a positive threshold was requested
        if args.validation_error_threshold > 0.0:
            validation_set = VignetteSet(problem_number,
                                         args.nb_validation_samples, args.batch_size,
                                         cuda = torch.cuda.is_available(),
                                         logger = vignette_logger())
        else:
            validation_set = None

        train_model(model, model_filename, train_set, validation_set, nb_epochs_done = nb_epochs_done)
        log_string('saved_model ' + model_filename)

        nb_train_errors = nb_errors(model, train_set)

        log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_train_errors / train_set.nb_samples,
            nb_train_errors,
            train_set.nb_samples)
        )

    ##################################################
    # Test if necessary

    if nb_epochs_done < args.nb_epochs or args.test_loaded_models:

        t = time.time()

        test_set = VignetteSet(problem_number,
                               args.nb_test_samples, args.batch_size,
                               cuda = torch.cuda.is_available())

        # Fix: honor --save_test_mistakes, which was parsed but never
        # used; mistake images were previously written unconditionally
        if args.save_test_mistakes:
            mistake_filename_pattern = 'mistake_{:06d}_{:d}.png'
        else:
            mistake_filename_pattern = None

        nb_test_errors = nb_errors(model, test_set,
                                   mistake_filename_pattern = mistake_filename_pattern)

        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_test_errors / test_set.nb_samples,
            nb_test_errors,
            test_set.nb_samples)
        )
574
575 ######################################################################