Fix.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27 import distutils.util
28
29 from colorama import Fore, Back, Style
30
31 # Pytorch
32
33 import torch
34
35 from torch import optim
36 from torch import FloatTensor as Tensor
37 from torch.autograd import Variable
38 from torch import nn
39 from torch.nn import functional as fn
40 from torchvision import datasets, transforms, utils
41
42 # SVRT
43
44 import svrtset
45
46 ######################################################################
47
48 parser = argparse.ArgumentParser(
49     description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
50     formatter_class = argparse.ArgumentDefaultsHelpFormatter
51 )
52
53 parser.add_argument('--nb_train_samples',
54                     type = int, default = 100000)
55
56 parser.add_argument('--nb_test_samples',
57                     type = int, default = 10000)
58
59 parser.add_argument('--nb_validation_samples',
60                     type = int, default = 10000)
61
62 parser.add_argument('--validation_error_threshold',
63                     type = float, default = 0.0,
64                     help = 'Early training termination criterion')
65
66 parser.add_argument('--nb_epochs',
67                     type = int, default = 50)
68
69 parser.add_argument('--batch_size',
70                     type = int, default = 100)
71
72 parser.add_argument('--log_file',
73                     type = str, default = 'default.log')
74
75 parser.add_argument('--compress_vignettes',
76                     type = distutils.util.strtobool, default = 'True',
77                     help = 'Use lossless compression to reduce the memory footprint')
78
79 parser.add_argument('--deep_model',
80                     type = distutils.util.strtobool, default = 'True',
81                     help = 'Use Afroze\'s Alexnet-like deep model')
82
83 parser.add_argument('--test_loaded_models',
84                     type = distutils.util.strtobool, default = 'False',
85                     help = 'Should we compute the test errors of loaded models')
86
87 parser.add_argument('--problems',
88                     type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
89                     help = 'What problems to process')
90
91 args = parser.parse_args()
92
93 ######################################################################
94
95 log_file = open(args.log_file, 'a')
96 pred_log_t = None
97
98 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
99
100 # Log and prints the string, with a time stamp. Does not log the
101 # remark
102 def log_string(s, remark = ''):
103     global pred_log_t
104
105     t = time.time()
106
107     if pred_log_t is None:
108         elapsed = 'start'
109     else:
110         elapsed = '+{:.02f}s'.format(t - pred_log_t)
111
112     pred_log_t = t
113
114     log_file.write('[' + time.ctime() + '] ' + elapsed + ' ' + s + '\n')
115     log_file.flush()
116
117     print(Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
118
119 ######################################################################
120
121 # Afroze's ShallowNet
122
123 #                       map size   nb. maps
124 #                     ----------------------
125 #    input                128x128    1
126 # -- conv(21x21 x 6)   -> 108x108    6
127 # -- max(2x2)          -> 54x54      6
128 # -- conv(19x19 x 16)  -> 36x36      16
129 # -- max(2x2)          -> 18x18      16
130 # -- conv(18x18 x 120) -> 1x1        120
131 # -- reshape           -> 120        1
132 # -- full(120x84)      -> 84         1
133 # -- full(84x2)        -> 2          1
134
135 class AfrozeShallowNet(nn.Module):
136     def __init__(self):
137         super(AfrozeShallowNet, self).__init__()
138         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
139         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
140         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
141         self.fc1 = nn.Linear(120, 84)
142         self.fc2 = nn.Linear(84, 2)
143         self.name = 'shallownet'
144
145     def forward(self, x):
146         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
147         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
148         x = fn.relu(self.conv3(x))
149         x = x.view(-1, 120)
150         x = fn.relu(self.fc1(x))
151         x = self.fc2(x)
152         return x
153
154 ######################################################################
155
156 # Afroze's DeepNet
157
158 class AfrozeDeepNet(nn.Module):
159     def __init__(self):
160         super(AfrozeDeepNet, self).__init__()
161         self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
162         self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
163         self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
164         self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
165         self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
166         self.fc1 = nn.Linear(1536, 256)
167         self.fc2 = nn.Linear(256, 256)
168         self.fc3 = nn.Linear(256, 2)
169         self.name = 'deepnet'
170
171     def forward(self, x):
172         x = self.conv1(x)
173         x = fn.max_pool2d(x, kernel_size=2)
174         x = fn.relu(x)
175
176         x = self.conv2(x)
177         x = fn.max_pool2d(x, kernel_size=2)
178         x = fn.relu(x)
179
180         x = self.conv3(x)
181         x = fn.relu(x)
182
183         x = self.conv4(x)
184         x = fn.relu(x)
185
186         x = self.conv5(x)
187         x = fn.max_pool2d(x, kernel_size=2)
188         x = fn.relu(x)
189
190         x = x.view(-1, 1536)
191
192         x = self.fc1(x)
193         x = fn.relu(x)
194
195         x = self.fc2(x)
196         x = fn.relu(x)
197
198         x = self.fc3(x)
199
200         return x
201
202 ######################################################################
203
204 def nb_errors(model, data_set):
205     ne = 0
206     for b in range(0, data_set.nb_batches):
207         input, target = data_set.get_batch(b)
208         output = model.forward(Variable(input))
209         wta_prediction = output.data.max(1)[1].view(-1)
210
211         for i in range(0, data_set.batch_size):
212             if wta_prediction[i] != target[i]:
213                 ne = ne + 1
214
215     return ne
216
217 ######################################################################
218
219 def train_model(model, train_set, validation_set):
220     batch_size = args.batch_size
221     criterion = nn.CrossEntropyLoss()
222
223     if torch.cuda.is_available():
224         criterion.cuda()
225
226     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
227
228     start_t = time.time()
229
230     for e in range(0, args.nb_epochs):
231         acc_loss = 0.0
232         for b in range(0, train_set.nb_batches):
233             input, target = train_set.get_batch(b)
234             output = model.forward(Variable(input))
235             loss = criterion(output, Variable(target))
236             acc_loss = acc_loss + loss.data[0]
237             model.zero_grad()
238             loss.backward()
239             optimizer.step()
240         dt = (time.time() - start_t) / (e + 1)
241
242         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
243                    ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
244
245         if validation_set is not None:
246             nb_validation_errors = nb_errors(model, validation_set)
247
248             log_string('validation_error {:.02f}% {:d} {:d}'.format(
249                 100 * nb_validation_errors / validation_set.nb_samples,
250                 nb_validation_errors,
251                 validation_set.nb_samples)
252             )
253
254             if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
255                 log_string('below validation_error_threshold')
256                 break
257
258     return model
259
260 ######################################################################
261
262 for arg in vars(args):
263     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
264
265 ######################################################################
266
267 def int_to_suffix(n):
268     if n >= 1000000 and n%1000000 == 0:
269         return str(n//1000000) + 'M'
270     elif n >= 1000 and n%1000 == 0:
271         return str(n//1000) + 'K'
272     else:
273         return str(n)
274
275 class vignette_logger():
276     def __init__(self, delay_min = 60):
277         self.start_t = time.time()
278         self.last_t = self.start_t
279         self.delay_min = delay_min
280
281     def __call__(self, n, m):
282         t = time.time()
283         if t > self.last_t + self.delay_min:
284             dt = (t - self.start_t) / m
285             log_string('sample_generation {:d} / {:d}'.format(
286                 m,
287                 n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
288             )
289             self.last_t = t
290
291 ######################################################################
292
293 if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
294     print('The number of samples must be a multiple of the batch size.')
295     raise
296
297 log_string('############### start ###############')
298
299 if args.compress_vignettes:
300     log_string('using_compressed_vignettes')
301     VignetteSet = svrtset.CompressedVignetteSet
302 else:
303     log_string('using_uncompressed_vignettes')
304     VignetteSet = svrtset.VignetteSet
305
306 for problem_number in map(int, args.problems.split(',')):
307
308     log_string('############### problem ' + str(problem_number) + ' ###############')
309
310     if args.deep_model:
311         model = AfrozeDeepNet()
312     else:
313         model = AfrozeShallowNet()
314
315     if torch.cuda.is_available(): model.cuda()
316
317     model_filename = model.name + '_pb:' + \
318                      str(problem_number) + '_ns:' + \
319                      int_to_suffix(args.nb_train_samples) + '.param'
320
321     nb_parameters = 0
322     for p in model.parameters(): nb_parameters += p.numel()
323     log_string('nb_parameters {:d}'.format(nb_parameters))
324
325     ##################################################
326     # Tries to load the model
327
328     need_to_train = False
329     try:
330         model.load_state_dict(torch.load(model_filename))
331         log_string('loaded_model ' + model_filename)
332     except:
333         need_to_train = True
334
335     ##################################################
336     # Train if necessary
337
338     if need_to_train:
339
340         log_string('training_model ' + model_filename)
341
342         t = time.time()
343
344         train_set = VignetteSet(problem_number,
345                                 args.nb_train_samples, args.batch_size,
346                                 cuda = torch.cuda.is_available(),
347                                 logger = vignette_logger())
348
349         log_string('data_generation {:0.2f} samples / s'.format(
350             train_set.nb_samples / (time.time() - t))
351         )
352
353         if args.validation_error_threshold > 0.0:
354             validation_set = VignetteSet(problem_number,
355                                          args.nb_validation_samples, args.batch_size,
356                                          cuda = torch.cuda.is_available(),
357                                          logger = vignette_logger())
358         else:
359             validation_set = None
360
361         train_model(model, train_set, validation_set)
362         torch.save(model.state_dict(), model_filename)
363         log_string('saved_model ' + model_filename)
364
365         nb_train_errors = nb_errors(model, train_set)
366
367         log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
368             problem_number,
369             100 * nb_train_errors / train_set.nb_samples,
370             nb_train_errors,
371             train_set.nb_samples)
372         )
373
374     ##################################################
375     # Test if necessary
376
377     if need_to_train or args.test_loaded_models:
378
379         t = time.time()
380
381         test_set = VignetteSet(problem_number,
382                                args.nb_test_samples, args.batch_size,
383                                cuda = torch.cuda.is_available())
384
385         nb_test_errors = nb_errors(model, test_set)
386
387         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
388             problem_number,
389             100 * nb_test_errors / test_set.nb_samples,
390             nb_test_errors,
391             test_set.nb_samples)
392         )
393
394 ######################################################################