7fe2db2d569aefc92e904ff21d11cdfec8e04d61
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27 import distutils.util
28 import re
29
30 from colorama import Fore, Back, Style
31
32 # Pytorch
33
34 import torch
35 import torchvision
36
37 from torch import optim
38 from torch import multiprocessing
39 from torch import FloatTensor as Tensor
40 from torch.autograd import Variable
41 from torch import nn
42 from torch.nn import functional as fn
43
44 from torchvision import datasets, transforms, utils
45
46 # SVRT
47
48 import svrtset
49
50 ######################################################################
51
52 parser = argparse.ArgumentParser(
53     description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
54     formatter_class = argparse.ArgumentDefaultsHelpFormatter
55 )
56
57 parser.add_argument('--nb_train_samples',
58                     type = int, default = 100000)
59
60 parser.add_argument('--nb_test_samples',
61                     type = int, default = 10000)
62
63 parser.add_argument('--nb_validation_samples',
64                     type = int, default = 10000)
65
66 parser.add_argument('--validation_error_threshold',
67                     type = float, default = 0.0,
68                     help = 'Early training termination criterion')
69
70 parser.add_argument('--nb_epochs',
71                     type = int, default = 50)
72
73 parser.add_argument('--batch_size',
74                     type = int, default = 100)
75
76 parser.add_argument('--log_file',
77                     type = str, default = 'default.log')
78
79 parser.add_argument('--nb_exemplar_vignettes',
80                     type = int, default = -1)
81
82 parser.add_argument('--compress_vignettes',
83                     type = distutils.util.strtobool, default = 'True',
84                     help = 'Use lossless compression to reduce the memory footprint')
85
86 parser.add_argument('--deep_model',
87                     type = distutils.util.strtobool, default = 'True',
88                     help = 'Use Afroze\'s Alexnet-like deep model')
89
90 parser.add_argument('--test_loaded_models',
91                     type = distutils.util.strtobool, default = 'False',
92                     help = 'Should we compute the test errors of loaded models')
93
94 parser.add_argument('--problems',
95                     type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
96                     help = 'What problems to process')
97
98 args = parser.parse_args()
99
100 ######################################################################
101
102 log_file = open(args.log_file, 'a')
103 pred_log_t = None
104 last_tag_t = time.time()
105
106 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
107
108 # Log and prints the string, with a time stamp. Does not log the
109 # remark
110
111 def log_string(s, remark = ''):
112     global pred_log_t, last_tag_t
113
114     t = time.time()
115
116     if pred_log_t is None:
117         elapsed = 'start'
118     else:
119         elapsed = '+{:.02f}s'.format(t - pred_log_t)
120
121     pred_log_t = t
122
123     if t > last_tag_t + 3600:
124         last_tag_t = t
125         print(Fore.RED + time.ctime() + Style.RESET_ALL)
126
127     log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n')
128     log_file.flush()
129
130     print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
131
132 ######################################################################
133
134 # Afroze's ShallowNet
135
136 #                       map size   nb. maps
137 #                     ----------------------
138 #    input                128x128    1
139 # -- conv(21x21 x 6)   -> 108x108    6
140 # -- max(2x2)          -> 54x54      6
141 # -- conv(19x19 x 16)  -> 36x36      16
142 # -- max(2x2)          -> 18x18      16
143 # -- conv(18x18 x 120) -> 1x1        120
144 # -- reshape           -> 120        1
145 # -- full(120x84)      -> 84         1
146 # -- full(84x2)        -> 2          1
147
148 class AfrozeShallowNet(nn.Module):
149     def __init__(self):
150         super(AfrozeShallowNet, self).__init__()
151         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
152         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
153         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
154         self.fc1 = nn.Linear(120, 84)
155         self.fc2 = nn.Linear(84, 2)
156         self.name = 'shallownet'
157
158     def forward(self, x):
159         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
160         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
161         x = fn.relu(self.conv3(x))
162         x = x.view(-1, 120)
163         x = fn.relu(self.fc1(x))
164         x = self.fc2(x)
165         return x
166
167 ######################################################################
168
169 # Afroze's DeepNet
170
171 class AfrozeDeepNet(nn.Module):
172     def __init__(self):
173         super(AfrozeDeepNet, self).__init__()
174         self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
175         self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
176         self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
177         self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
178         self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
179         self.fc1 = nn.Linear(1536, 256)
180         self.fc2 = nn.Linear(256, 256)
181         self.fc3 = nn.Linear(256, 2)
182         self.name = 'deepnet'
183
184     def forward(self, x):
185         x = self.conv1(x)
186         x = fn.max_pool2d(x, kernel_size=2)
187         x = fn.relu(x)
188
189         x = self.conv2(x)
190         x = fn.max_pool2d(x, kernel_size=2)
191         x = fn.relu(x)
192
193         x = self.conv3(x)
194         x = fn.relu(x)
195
196         x = self.conv4(x)
197         x = fn.relu(x)
198
199         x = self.conv5(x)
200         x = fn.max_pool2d(x, kernel_size=2)
201         x = fn.relu(x)
202
203         x = x.view(-1, 1536)
204
205         x = self.fc1(x)
206         x = fn.relu(x)
207
208         x = self.fc2(x)
209         x = fn.relu(x)
210
211         x = self.fc3(x)
212
213         return x
214
215 ######################################################################
216
217 def nb_errors(model, data_set):
218     ne = 0
219     for b in range(0, data_set.nb_batches):
220         input, target = data_set.get_batch(b)
221         output = model.forward(Variable(input))
222         wta_prediction = output.data.max(1)[1].view(-1)
223
224         for i in range(0, data_set.batch_size):
225             if wta_prediction[i] != target[i]:
226                 ne = ne + 1
227
228     return ne
229
230 ######################################################################
231
232 def train_model(model, train_set, validation_set):
233     batch_size = args.batch_size
234     criterion = nn.CrossEntropyLoss()
235
236     if torch.cuda.is_available():
237         criterion.cuda()
238
239     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
240
241     start_t = time.time()
242
243     for e in range(0, args.nb_epochs):
244         acc_loss = 0.0
245         for b in range(0, train_set.nb_batches):
246             input, target = train_set.get_batch(b)
247             output = model.forward(Variable(input))
248             loss = criterion(output, Variable(target))
249             acc_loss = acc_loss + loss.data[0]
250             model.zero_grad()
251             loss.backward()
252             optimizer.step()
253         dt = (time.time() - start_t) / (e + 1)
254
255         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
256                    ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
257
258         if validation_set is not None:
259             nb_validation_errors = nb_errors(model, validation_set)
260
261             log_string('validation_error {:.02f}% {:d} {:d}'.format(
262                 100 * nb_validation_errors / validation_set.nb_samples,
263                 nb_validation_errors,
264                 validation_set.nb_samples)
265             )
266
267             if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
268                 log_string('below validation_error_threshold')
269                 break
270
271     return model
272
273 ######################################################################
274
275 for arg in vars(args):
276     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
277
278 ######################################################################
279
280 def int_to_suffix(n):
281     if n >= 1000000 and n%1000000 == 0:
282         return str(n//1000000) + 'M'
283     elif n >= 1000 and n%1000 == 0:
284         return str(n//1000) + 'K'
285     else:
286         return str(n)
287
288 class vignette_logger():
289     def __init__(self, delay_min = 60):
290         self.start_t = time.time()
291         self.last_t = self.start_t
292         self.delay_min = delay_min
293
294     def __call__(self, n, m):
295         t = time.time()
296         if t > self.last_t + self.delay_min:
297             dt = (t - self.start_t) / m
298             log_string('sample_generation {:d} / {:d}'.format(
299                 m,
300                 n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
301             )
302             self.last_t = t
303
304 def save_examplar_vignettes(data_set, nb, name):
305     n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)
306
307     for k in range(0, nb):
308         b = n[k] // data_set.batch_size
309         m = n[k] % data_set.batch_size
310         i, t = data_set.get_batch(b)
311         i = i[m].float()
312         i.sub_(i.min())
313         i.div_(i.max())
314         if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2))
315         patchwork[k].copy_(i)
316
317     torchvision.utils.save_image(patchwork, name)
318
319 ######################################################################
320
321 if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
322     print('The number of samples must be a multiple of the batch size.')
323     raise
324
325 log_string('############### start ###############')
326
327 if args.compress_vignettes:
328     log_string('using_compressed_vignettes')
329     VignetteSet = svrtset.CompressedVignetteSet
330 else:
331     log_string('using_uncompressed_vignettes')
332     VignetteSet = svrtset.VignetteSet
333
334 for problem_number in map(int, args.problems.split(',')):
335
336     log_string('############### problem ' + str(problem_number) + ' ###############')
337
338     if args.deep_model:
339         model = AfrozeDeepNet()
340     else:
341         model = AfrozeShallowNet()
342
343     if torch.cuda.is_available(): model.cuda()
344
345     model_filename = model.name + '_pb:' + \
346                      str(problem_number) + '_ns:' + \
347                      int_to_suffix(args.nb_train_samples) + '.param'
348
349     nb_parameters = 0
350     for p in model.parameters(): nb_parameters += p.numel()
351     log_string('nb_parameters {:d}'.format(nb_parameters))
352
353     ##################################################
354     # Tries to load the model
355
356     need_to_train = False
357     try:
358         model.load_state_dict(torch.load(model_filename))
359         log_string('loaded_model ' + model_filename)
360     except:
361         need_to_train = True
362
363     ##################################################
364     # Train if necessary
365
366     if need_to_train:
367
368         log_string('training_model ' + model_filename)
369
370         t = time.time()
371
372         train_set = VignetteSet(problem_number,
373                                 args.nb_train_samples, args.batch_size,
374                                 cuda = torch.cuda.is_available(),
375                                 logger = vignette_logger())
376
377         log_string('data_generation {:0.2f} samples / s'.format(
378             train_set.nb_samples / (time.time() - t))
379         )
380
381         if args.nb_exemplar_vignettes > 0:
382             save_examplar_vignettes(train_set, args.nb_exemplar_vignettes,
383                                     'examplar_{:d}.png'.format(problem_number))
384
385         if args.validation_error_threshold > 0.0:
386             validation_set = VignetteSet(problem_number,
387                                          args.nb_validation_samples, args.batch_size,
388                                          cuda = torch.cuda.is_available(),
389                                          logger = vignette_logger())
390         else:
391             validation_set = None
392
393         train_model(model, train_set, validation_set)
394         torch.save(model.state_dict(), model_filename)
395         log_string('saved_model ' + model_filename)
396
397         nb_train_errors = nb_errors(model, train_set)
398
399         log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
400             problem_number,
401             100 * nb_train_errors / train_set.nb_samples,
402             nb_train_errors,
403             train_set.nb_samples)
404         )
405
406     ##################################################
407     # Test if necessary
408
409     if need_to_train or args.test_loaded_models:
410
411         t = time.time()
412
413         test_set = VignetteSet(problem_number,
414                                args.nb_test_samples, args.batch_size,
415                                cuda = torch.cuda.is_available())
416
417         nb_test_errors = nb_errors(model, test_set)
418
419         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
420             problem_number,
421             100 * nb_test_errors / test_set.nb_samples,
422             nb_test_errors,
423             test_set.nb_samples)
424         )
425
426 ######################################################################