Added a logger with ETA for the sample generation.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27 import distutils.util
28
29 from colorama import Fore, Back, Style
30
31 # Pytorch
32
33 import torch
34
35 from torch import optim
36 from torch import FloatTensor as Tensor
37 from torch.autograd import Variable
38 from torch import nn
39 from torch.nn import functional as fn
40 from torchvision import datasets, transforms, utils
41
42 # SVRT
43
44 import svrtset
45
46 ######################################################################
47
48 parser = argparse.ArgumentParser(
49     description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
50     formatter_class = argparse.ArgumentDefaultsHelpFormatter
51 )
52
53 parser.add_argument('--nb_train_samples',
54                     type = int, default = 100000)
55
56 parser.add_argument('--nb_test_samples',
57                     type = int, default = 10000)
58
59 parser.add_argument('--nb_epochs',
60                     type = int, default = 50)
61
62 parser.add_argument('--batch_size',
63                     type = int, default = 100)
64
65 parser.add_argument('--log_file',
66                     type = str, default = 'default.log')
67
68 parser.add_argument('--compress_vignettes',
69                     type = distutils.util.strtobool, default = 'True',
70                     help = 'Use lossless compression to reduce the memory footprint')
71
72 parser.add_argument('--deep_model',
73                     type = distutils.util.strtobool, default = 'True',
74                     help = 'Use Afroze\'s Alexnet-like deep model')
75
76 parser.add_argument('--test_loaded_models',
77                     type = distutils.util.strtobool, default = 'False',
78                     help = 'Should we compute the test errors of loaded models')
79
80 args = parser.parse_args()
81
82 ######################################################################
83
84 log_file = open(args.log_file, 'w')
85 pred_log_t = None
86
87 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
88
89 # Log and prints the string, with a time stamp. Does not log the
90 # remark
91 def log_string(s, remark = ''):
92     global pred_log_t
93
94     t = time.time()
95
96     if pred_log_t is None:
97         elapsed = 'start'
98     else:
99         elapsed = '+{:.02f}s'.format(t - pred_log_t)
100
101     pred_log_t = t
102
103     log_file.write('[' + time.ctime() + '] ' + elapsed + ' ' + s + '\n')
104     log_file.flush()
105
106     print(Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
107
108 ######################################################################
109
110 # Afroze's ShallowNet
111
112 #                       map size   nb. maps
113 #                     ----------------------
114 #    input                128x128    1
115 # -- conv(21x21 x 6)   -> 108x108    6
116 # -- max(2x2)          -> 54x54      6
117 # -- conv(19x19 x 16)  -> 36x36      16
118 # -- max(2x2)          -> 18x18      16
119 # -- conv(18x18 x 120) -> 1x1        120
120 # -- reshape           -> 120        1
121 # -- full(120x84)      -> 84         1
122 # -- full(84x2)        -> 2          1
123
124 class AfrozeShallowNet(nn.Module):
125     def __init__(self):
126         super(AfrozeShallowNet, self).__init__()
127         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
128         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
129         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
130         self.fc1 = nn.Linear(120, 84)
131         self.fc2 = nn.Linear(84, 2)
132         self.name = 'shallownet'
133
134     def forward(self, x):
135         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
136         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
137         x = fn.relu(self.conv3(x))
138         x = x.view(-1, 120)
139         x = fn.relu(self.fc1(x))
140         x = self.fc2(x)
141         return x
142
143 ######################################################################
144
145 # Afroze's DeepNet
146
147 class AfrozeDeepNet(nn.Module):
148     def __init__(self):
149         super(AfrozeDeepNet, self).__init__()
150         self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
151         self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
152         self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
153         self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
154         self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
155         self.fc1 = nn.Linear(1536, 256)
156         self.fc2 = nn.Linear(256, 256)
157         self.fc3 = nn.Linear(256, 2)
158         self.name = 'deepnet'
159
160     def forward(self, x):
161         x = self.conv1(x)
162         x = fn.max_pool2d(x, kernel_size=2)
163         x = fn.relu(x)
164
165         x = self.conv2(x)
166         x = fn.max_pool2d(x, kernel_size=2)
167         x = fn.relu(x)
168
169         x = self.conv3(x)
170         x = fn.relu(x)
171
172         x = self.conv4(x)
173         x = fn.relu(x)
174
175         x = self.conv5(x)
176         x = fn.max_pool2d(x, kernel_size=2)
177         x = fn.relu(x)
178
179         x = x.view(-1, 1536)
180
181         x = self.fc1(x)
182         x = fn.relu(x)
183
184         x = self.fc2(x)
185         x = fn.relu(x)
186
187         x = self.fc3(x)
188
189         return x
190
191 ######################################################################
192
193 def train_model(model, train_set):
194     batch_size = args.batch_size
195     criterion = nn.CrossEntropyLoss()
196
197     if torch.cuda.is_available():
198         criterion.cuda()
199
200     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
201
202     start_t = time.time()
203
204     for e in range(0, args.nb_epochs):
205         acc_loss = 0.0
206         for b in range(0, train_set.nb_batches):
207             input, target = train_set.get_batch(b)
208             output = model.forward(Variable(input))
209             loss = criterion(output, Variable(target))
210             acc_loss = acc_loss + loss.data[0]
211             model.zero_grad()
212             loss.backward()
213             optimizer.step()
214         dt = (time.time() - start_t) / (e + 1)
215         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
216                    ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
217
218     return model
219
220 ######################################################################
221
222 def nb_errors(model, data_set):
223     ne = 0
224     for b in range(0, data_set.nb_batches):
225         input, target = data_set.get_batch(b)
226         output = model.forward(Variable(input))
227         wta_prediction = output.data.max(1)[1].view(-1)
228
229         for i in range(0, data_set.batch_size):
230             if wta_prediction[i] != target[i]:
231                 ne = ne + 1
232
233     return ne
234
235 ######################################################################
236
237 for arg in vars(args):
238     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
239
240 ######################################################################
241
242 def int_to_suffix(n):
243     if n >= 1000000 and n%1000000 == 0:
244         return str(n//1000000) + 'M'
245     elif n >= 1000 and n%1000 == 0:
246         return str(n//1000) + 'K'
247     else:
248         return str(n)
249
250 class vignette_logger():
251     def __init__(self, delay_min = 60):
252         self.start_t = time.time()
253         self.delay_min = delay_min
254
255     def __call__(self, n, m):
256         t = time.time()
257         if t > self.start_t + self.delay_min:
258             dt = (t - self.start_t) / m
259             log_string('sample_generation {:d} / {:d}'.format(
260                 m,
261                 n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
262             )
263             self.start_t = t
264
265 ######################################################################
266
267 if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
268     print('The number of samples must be a multiple of the batch size.')
269     raise
270
271 if args.compress_vignettes:
272     log_string('using_compressed_vignettes')
273     VignetteSet = svrtset.CompressedVignetteSet
274 else:
275     log_string('using_uncompressed_vignettes')
276     VignetteSet = svrtset.VignetteSet
277
278 for problem_number in range(1, 24):
279
280     log_string('############### problem ' + str(problem_number) + ' ###############')
281
282     if args.deep_model:
283         model = AfrozeDeepNet()
284     else:
285         model = AfrozeShallowNet()
286
287     if torch.cuda.is_available(): model.cuda()
288
289     model_filename = model.name + '_pb:' + \
290                      str(problem_number) + '_ns:' + \
291                      int_to_suffix(args.nb_train_samples) + '.param'
292
293     nb_parameters = 0
294     for p in model.parameters(): nb_parameters += p.numel()
295     log_string('nb_parameters {:d}'.format(nb_parameters))
296
297     ##################################################
298     # Tries to load the model
299
300     need_to_train = False
301     try:
302         model.load_state_dict(torch.load(model_filename))
303         log_string('loaded_model ' + model_filename)
304     except:
305         need_to_train = True
306
307     ##################################################
308     # Train if necessary
309
310     if need_to_train:
311
312         log_string('training_model ' + model_filename)
313
314         t = time.time()
315
316         train_set = VignetteSet(problem_number,
317                                 args.nb_train_samples, args.batch_size,
318                                 cuda = torch.cuda.is_available(),
319                                 logger = vignette_logger())
320
321         log_string('data_generation {:0.2f} samples / s'.format(
322             train_set.nb_samples / (time.time() - t))
323         )
324
325         train_model(model, train_set)
326         torch.save(model.state_dict(), model_filename)
327         log_string('saved_model ' + model_filename)
328
329         nb_train_errors = nb_errors(model, train_set)
330
331         log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
332             problem_number,
333             100 * nb_train_errors / train_set.nb_samples,
334             nb_train_errors,
335             train_set.nb_samples)
336         )
337
338     ##################################################
339     # Test if necessary
340
341     if need_to_train or args.test_loaded_models:
342
343         t = time.time()
344
345         test_set = VignetteSet(problem_number,
346                                args.nb_test_samples, args.batch_size,
347                                cuda = torch.cuda.is_available())
348
349         log_string('data_generation {:0.2f} samples / s'.format(
350             test_set.nb_samples / (time.time() - t))
351         )
352
353         nb_test_errors = nb_errors(model, test_set)
354
355         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
356             problem_number,
357             100 * nb_test_errors / test_set.nb_samples,
358             nb_test_errors,
359             test_set.nb_samples)
360         )
361
362 ######################################################################