Re-oups.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27 import distutils.util
28 import re
29
30 from colorama import Fore, Back, Style
31
32 # Pytorch
33
34 import torch
35 import torchvision
36
37 from torch import optim
38 from torch import multiprocessing
39 from torch import FloatTensor as Tensor
40 from torch.autograd import Variable
41 from torch import nn
42 from torch.nn import functional as fn
43
44 from torchvision import datasets, transforms, utils
45
46 # SVRT
47
48 import svrtset
49
50 ######################################################################
51
52 parser = argparse.ArgumentParser(
53     description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
54     formatter_class = argparse.ArgumentDefaultsHelpFormatter
55 )
56
57 parser.add_argument('--nb_train_samples',
58                     type = int, default = 100000)
59
60 parser.add_argument('--nb_test_samples',
61                     type = int, default = 10000)
62
63 parser.add_argument('--nb_validation_samples',
64                     type = int, default = 10000)
65
66 parser.add_argument('--validation_error_threshold',
67                     type = float, default = 0.0,
68                     help = 'Early training termination criterion')
69
70 parser.add_argument('--nb_epochs',
71                     type = int, default = 50)
72
73 parser.add_argument('--batch_size',
74                     type = int, default = 100)
75
76 parser.add_argument('--log_file',
77                     type = str, default = 'default.log')
78
79 parser.add_argument('--nb_exemplar_vignettes',
80                     type = int, default = 32)
81
82 parser.add_argument('--compress_vignettes',
83                     type = distutils.util.strtobool, default = 'True',
84                     help = 'Use lossless compression to reduce the memory footprint')
85
86 parser.add_argument('--model',
87                     type = str, default = 'deepnet',
88                     help = 'What model to use')
89
90 parser.add_argument('--test_loaded_models',
91                     type = distutils.util.strtobool, default = 'False',
92                     help = 'Should we compute the test errors of loaded models')
93
94 parser.add_argument('--problems',
95                     type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
96                     help = 'What problems to process')
97
98 args = parser.parse_args()
99
100 ######################################################################
101
102 log_file = open(args.log_file, 'a')
103 pred_log_t = None
104 last_tag_t = time.time()
105
106 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
107
108 # Log and prints the string, with a time stamp. Does not log the
109 # remark
110
111 def log_string(s, remark = ''):
112     global pred_log_t, last_tag_t
113
114     t = time.time()
115
116     if pred_log_t is None:
117         elapsed = 'start'
118     else:
119         elapsed = '+{:.02f}s'.format(t - pred_log_t)
120
121     pred_log_t = t
122
123     if t > last_tag_t + 3600:
124         last_tag_t = t
125         print(Fore.RED + time.ctime() + Style.RESET_ALL)
126
127     log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n')
128     log_file.flush()
129
130     print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
131
132 ######################################################################
133
134 # Afroze's ShallowNet
135
136 #                       map size   nb. maps
137 #                     ----------------------
138 #    input                128x128    1
139 # -- conv(21x21 x 6)   -> 108x108    6
140 # -- max(2x2)          -> 54x54      6
141 # -- conv(19x19 x 16)  -> 36x36      16
142 # -- max(2x2)          -> 18x18      16
143 # -- conv(18x18 x 120) -> 1x1        120
144 # -- reshape           -> 120        1
145 # -- full(120x84)      -> 84         1
146 # -- full(84x2)        -> 2          1
147
148 class AfrozeShallowNet(nn.Module):
149     name = 'shallownet'
150
151     def __init__(self):
152         super(AfrozeShallowNet, self).__init__()
153         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
154         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
155         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
156         self.fc1 = nn.Linear(120, 84)
157         self.fc2 = nn.Linear(84, 2)
158
159     def forward(self, x):
160         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
161         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
162         x = fn.relu(self.conv3(x))
163         x = x.view(-1, 120)
164         x = fn.relu(self.fc1(x))
165         x = self.fc2(x)
166         return x
167
168 ######################################################################
169
170 # Afroze's DeepNet
171
172 class AfrozeDeepNet(nn.Module):
173
174     name = 'deepnet'
175
176     def __init__(self):
177         super(AfrozeDeepNet, self).__init__()
178         self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
179         self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
180         self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
181         self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
182         self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
183         self.fc1 = nn.Linear(1536, 256)
184         self.fc2 = nn.Linear(256, 256)
185         self.fc3 = nn.Linear(256, 2)
186
187     def forward(self, x):
188         x = self.conv1(x)
189         x = fn.max_pool2d(x, kernel_size=2)
190         x = fn.relu(x)
191
192         x = self.conv2(x)
193         x = fn.max_pool2d(x, kernel_size=2)
194         x = fn.relu(x)
195
196         x = self.conv3(x)
197         x = fn.relu(x)
198
199         x = self.conv4(x)
200         x = fn.relu(x)
201
202         x = self.conv5(x)
203         x = fn.max_pool2d(x, kernel_size=2)
204         x = fn.relu(x)
205
206         x = x.view(-1, 1536)
207
208         x = self.fc1(x)
209         x = fn.relu(x)
210
211         x = self.fc2(x)
212         x = fn.relu(x)
213
214         x = self.fc3(x)
215
216         return x
217
218 ######################################################################
219
220 class DeepNet2(nn.Module):
221     name = 'deepnet2'
222
223     def __init__(self):
224         super(DeepNet2, self).__init__()
225         self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
226         self.conv2 = nn.Conv2d( 32, 128, kernel_size=5, padding=2)
227         self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
228         self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
229         self.conv5 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
230         self.fc1 = nn.Linear(2048, 512)
231         self.fc2 = nn.Linear(512, 512)
232         self.fc3 = nn.Linear(512, 2)
233
234     def forward(self, x):
235         x = self.conv1(x)
236         x = fn.max_pool2d(x, kernel_size=2)
237         x = fn.relu(x)
238
239         x = self.conv2(x)
240         x = fn.max_pool2d(x, kernel_size=2)
241         x = fn.relu(x)
242
243         x = self.conv3(x)
244         x = fn.relu(x)
245
246         x = self.conv4(x)
247         x = fn.relu(x)
248
249         x = self.conv5(x)
250         x = fn.max_pool2d(x, kernel_size=2)
251         x = fn.relu(x)
252
253         x = x.view(-1, 2048)
254
255         x = self.fc1(x)
256         x = fn.relu(x)
257
258         x = self.fc2(x)
259         x = fn.relu(x)
260
261         x = self.fc3(x)
262
263         return x
264
265 ######################################################################
266
267 def nb_errors(model, data_set):
268     ne = 0
269     for b in range(0, data_set.nb_batches):
270         input, target = data_set.get_batch(b)
271         output = model.forward(Variable(input))
272         wta_prediction = output.data.max(1)[1].view(-1)
273
274         for i in range(0, data_set.batch_size):
275             if wta_prediction[i] != target[i]:
276                 ne = ne + 1
277
278     return ne
279
280 ######################################################################
281
282 def train_model(model, train_set, validation_set):
283     batch_size = args.batch_size
284     criterion = nn.CrossEntropyLoss()
285
286     if torch.cuda.is_available():
287         criterion.cuda()
288
289     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
290
291     start_t = time.time()
292
293     for e in range(0, args.nb_epochs):
294         acc_loss = 0.0
295         for b in range(0, train_set.nb_batches):
296             input, target = train_set.get_batch(b)
297             output = model.forward(Variable(input))
298             loss = criterion(output, Variable(target))
299             acc_loss = acc_loss + loss.data[0]
300             model.zero_grad()
301             loss.backward()
302             optimizer.step()
303         dt = (time.time() - start_t) / (e + 1)
304
305         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
306                    ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
307
308         if validation_set is not None:
309             nb_validation_errors = nb_errors(model, validation_set)
310
311             log_string('validation_error {:.02f}% {:d} {:d}'.format(
312                 100 * nb_validation_errors / validation_set.nb_samples,
313                 nb_validation_errors,
314                 validation_set.nb_samples)
315             )
316
317             if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
318                 log_string('below validation_error_threshold')
319                 break
320
321     return model
322
323 ######################################################################
324
325 for arg in vars(args):
326     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
327
328 ######################################################################
329
330 def int_to_suffix(n):
331     if n >= 1000000 and n%1000000 == 0:
332         return str(n//1000000) + 'M'
333     elif n >= 1000 and n%1000 == 0:
334         return str(n//1000) + 'K'
335     else:
336         return str(n)
337
338 class vignette_logger():
339     def __init__(self, delay_min = 60):
340         self.start_t = time.time()
341         self.last_t = self.start_t
342         self.delay_min = delay_min
343
344     def __call__(self, n, m):
345         t = time.time()
346         if t > self.last_t + self.delay_min:
347             dt = (t - self.start_t) / m
348             log_string('sample_generation {:d} / {:d}'.format(
349                 m,
350                 n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
351             )
352             self.last_t = t
353
354 def save_examplar_vignettes(data_set, nb, name):
355     n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)
356
357     for k in range(0, nb):
358         b = n[k] // data_set.batch_size
359         m = n[k] % data_set.batch_size
360         i, t = data_set.get_batch(b)
361         i = i[m].float()
362         i.sub_(i.min())
363         i.div_(i.max())
364         if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2))
365         patchwork[k].copy_(i)
366
367     torchvision.utils.save_image(patchwork, name)
368
369 ######################################################################
370
371 if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
372     print('The number of samples must be a multiple of the batch size.')
373     raise
374
375 log_string('############### start ###############')
376
377 if args.compress_vignettes:
378     log_string('using_compressed_vignettes')
379     VignetteSet = svrtset.CompressedVignetteSet
380 else:
381     log_string('using_uncompressed_vignettes')
382     VignetteSet = svrtset.VignetteSet
383
384 ########################################
385 model_class = None
386 for m in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2 ]:
387     if args.model == m.name:
388         model_class = m
389         break
390 if model_class is None:
391     print('Unknown model ' + args.model)
392     raise
393
394 log_string('using model class ' + m.name)
395 ########################################
396
397 for problem_number in map(int, args.problems.split(',')):
398
399     log_string('############### problem ' + str(problem_number) + ' ###############')
400
401     model = model_class()
402
403     if torch.cuda.is_available(): model.cuda()
404
405     model_filename = model.name + '_pb:' + \
406                      str(problem_number) + '_ns:' + \
407                      int_to_suffix(args.nb_train_samples) + '.param'
408
409     nb_parameters = 0
410     for p in model.parameters(): nb_parameters += p.numel()
411     log_string('nb_parameters {:d}'.format(nb_parameters))
412
413     ##################################################
414     # Tries to load the model
415
416     need_to_train = False
417     try:
418         model.load_state_dict(torch.load(model_filename))
419         log_string('loaded_model ' + model_filename)
420     except:
421         need_to_train = True
422
423     ##################################################
424     # Train if necessary
425
426     if need_to_train:
427
428         log_string('training_model ' + model_filename)
429
430         t = time.time()
431
432         train_set = VignetteSet(problem_number,
433                                 args.nb_train_samples, args.batch_size,
434                                 cuda = torch.cuda.is_available(),
435                                 logger = vignette_logger())
436
437         log_string('data_generation {:0.2f} samples / s'.format(
438             train_set.nb_samples / (time.time() - t))
439         )
440
441         if args.nb_exemplar_vignettes > 0:
442             save_examplar_vignettes(train_set, args.nb_exemplar_vignettes,
443                                     'examplar_{:d}.png'.format(problem_number))
444
445         if args.validation_error_threshold > 0.0:
446             validation_set = VignetteSet(problem_number,
447                                          args.nb_validation_samples, args.batch_size,
448                                          cuda = torch.cuda.is_available(),
449                                          logger = vignette_logger())
450         else:
451             validation_set = None
452
453         train_model(model, train_set, validation_set)
454         torch.save(model.state_dict(), model_filename)
455         log_string('saved_model ' + model_filename)
456
457         nb_train_errors = nb_errors(model, train_set)
458
459         log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
460             problem_number,
461             100 * nb_train_errors / train_set.nb_samples,
462             nb_train_errors,
463             train_set.nb_samples)
464         )
465
466     ##################################################
467     # Test if necessary
468
469     if need_to_train or args.test_loaded_models:
470
471         t = time.time()
472
473         test_set = VignetteSet(problem_number,
474                                args.nb_test_samples, args.batch_size,
475                                cuda = torch.cuda.is_available())
476
477         nb_test_errors = nb_errors(model, test_set)
478
479         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
480             problem_number,
481             100 * nb_test_errors / test_set.nb_samples,
482             nb_test_errors,
483             test_set.nb_samples)
484         )
485
486 ######################################################################