Append to the log instead of overwriting it.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27 import distutils.util
28
29 from colorama import Fore, Back, Style
30
31 # Pytorch
32
33 import torch
34
35 from torch import optim
36 from torch import FloatTensor as Tensor
37 from torch.autograd import Variable
38 from torch import nn
39 from torch.nn import functional as fn
40 from torchvision import datasets, transforms, utils
41
42 # SVRT
43
44 import svrtset
45
46 ######################################################################
47
48 parser = argparse.ArgumentParser(
49     description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
50     formatter_class = argparse.ArgumentDefaultsHelpFormatter
51 )
52
53 parser.add_argument('--nb_train_samples',
54                     type = int, default = 100000)
55
56 parser.add_argument('--nb_test_samples',
57                     type = int, default = 10000)
58
59 parser.add_argument('--nb_epochs',
60                     type = int, default = 50)
61
62 parser.add_argument('--batch_size',
63                     type = int, default = 100)
64
65 parser.add_argument('--log_file',
66                     type = str, default = 'default.log')
67
68 parser.add_argument('--compress_vignettes',
69                     type = distutils.util.strtobool, default = 'True',
70                     help = 'Use lossless compression to reduce the memory footprint')
71
72 parser.add_argument('--deep_model',
73                     type = distutils.util.strtobool, default = 'True',
74                     help = 'Use Afroze\'s Alexnet-like deep model')
75
76 parser.add_argument('--test_loaded_models',
77                     type = distutils.util.strtobool, default = 'False',
78                     help = 'Should we compute the test errors of loaded models')
79
80 parser.add_argument('--problems',
81                     type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
82                     help = 'What problems to process')
83
84 args = parser.parse_args()
85
86 ######################################################################
87
88 log_file = open(args.log_file, 'a')
89 pred_log_t = None
90
91 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
92
93 # Log and prints the string, with a time stamp. Does not log the
94 # remark
95 def log_string(s, remark = ''):
96     global pred_log_t
97
98     t = time.time()
99
100     if pred_log_t is None:
101         elapsed = 'start'
102     else:
103         elapsed = '+{:.02f}s'.format(t - pred_log_t)
104
105     pred_log_t = t
106
107     log_file.write('[' + time.ctime() + '] ' + elapsed + ' ' + s + '\n')
108     log_file.flush()
109
110     print(Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
111
112 ######################################################################
113
114 # Afroze's ShallowNet
115
116 #                       map size   nb. maps
117 #                     ----------------------
118 #    input                128x128    1
119 # -- conv(21x21 x 6)   -> 108x108    6
120 # -- max(2x2)          -> 54x54      6
121 # -- conv(19x19 x 16)  -> 36x36      16
122 # -- max(2x2)          -> 18x18      16
123 # -- conv(18x18 x 120) -> 1x1        120
124 # -- reshape           -> 120        1
125 # -- full(120x84)      -> 84         1
126 # -- full(84x2)        -> 2          1
127
128 class AfrozeShallowNet(nn.Module):
129     def __init__(self):
130         super(AfrozeShallowNet, self).__init__()
131         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
132         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
133         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
134         self.fc1 = nn.Linear(120, 84)
135         self.fc2 = nn.Linear(84, 2)
136         self.name = 'shallownet'
137
138     def forward(self, x):
139         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
140         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
141         x = fn.relu(self.conv3(x))
142         x = x.view(-1, 120)
143         x = fn.relu(self.fc1(x))
144         x = self.fc2(x)
145         return x
146
147 ######################################################################
148
149 # Afroze's DeepNet
150
151 class AfrozeDeepNet(nn.Module):
152     def __init__(self):
153         super(AfrozeDeepNet, self).__init__()
154         self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
155         self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
156         self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
157         self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
158         self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
159         self.fc1 = nn.Linear(1536, 256)
160         self.fc2 = nn.Linear(256, 256)
161         self.fc3 = nn.Linear(256, 2)
162         self.name = 'deepnet'
163
164     def forward(self, x):
165         x = self.conv1(x)
166         x = fn.max_pool2d(x, kernel_size=2)
167         x = fn.relu(x)
168
169         x = self.conv2(x)
170         x = fn.max_pool2d(x, kernel_size=2)
171         x = fn.relu(x)
172
173         x = self.conv3(x)
174         x = fn.relu(x)
175
176         x = self.conv4(x)
177         x = fn.relu(x)
178
179         x = self.conv5(x)
180         x = fn.max_pool2d(x, kernel_size=2)
181         x = fn.relu(x)
182
183         x = x.view(-1, 1536)
184
185         x = self.fc1(x)
186         x = fn.relu(x)
187
188         x = self.fc2(x)
189         x = fn.relu(x)
190
191         x = self.fc3(x)
192
193         return x
194
195 ######################################################################
196
197 def train_model(model, train_set):
198     batch_size = args.batch_size
199     criterion = nn.CrossEntropyLoss()
200
201     if torch.cuda.is_available():
202         criterion.cuda()
203
204     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
205
206     start_t = time.time()
207
208     for e in range(0, args.nb_epochs):
209         acc_loss = 0.0
210         for b in range(0, train_set.nb_batches):
211             input, target = train_set.get_batch(b)
212             output = model.forward(Variable(input))
213             loss = criterion(output, Variable(target))
214             acc_loss = acc_loss + loss.data[0]
215             model.zero_grad()
216             loss.backward()
217             optimizer.step()
218         dt = (time.time() - start_t) / (e + 1)
219         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
220                    ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
221
222     return model
223
224 ######################################################################
225
226 def nb_errors(model, data_set):
227     ne = 0
228     for b in range(0, data_set.nb_batches):
229         input, target = data_set.get_batch(b)
230         output = model.forward(Variable(input))
231         wta_prediction = output.data.max(1)[1].view(-1)
232
233         for i in range(0, data_set.batch_size):
234             if wta_prediction[i] != target[i]:
235                 ne = ne + 1
236
237     return ne
238
239 ######################################################################
240
241 for arg in vars(args):
242     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
243
244 ######################################################################
245
246 def int_to_suffix(n):
247     if n >= 1000000 and n%1000000 == 0:
248         return str(n//1000000) + 'M'
249     elif n >= 1000 and n%1000 == 0:
250         return str(n//1000) + 'K'
251     else:
252         return str(n)
253
254 class vignette_logger():
255     def __init__(self, delay_min = 60):
256         self.start_t = time.time()
257         self.last_t = self.start_t
258         self.delay_min = delay_min
259
260     def __call__(self, n, m):
261         t = time.time()
262         if t > self.last_t + self.delay_min:
263             dt = (t - self.start_t) / m
264             log_string('sample_generation {:d} / {:d}'.format(
265                 m,
266                 n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
267             )
268             self.last_t = t
269
270 ######################################################################
271
272 if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
273     print('The number of samples must be a multiple of the batch size.')
274     raise
275
276 log_string('############### start ###############')
277
278 if args.compress_vignettes:
279     log_string('using_compressed_vignettes')
280     VignetteSet = svrtset.CompressedVignetteSet
281 else:
282     log_string('using_uncompressed_vignettes')
283     VignetteSet = svrtset.VignetteSet
284
285 for problem_number in map(int, args.problems.split(',')):
286
287     log_string('############### problem ' + str(problem_number) + ' ###############')
288
289     if args.deep_model:
290         model = AfrozeDeepNet()
291     else:
292         model = AfrozeShallowNet()
293
294     if torch.cuda.is_available(): model.cuda()
295
296     model_filename = model.name + '_pb:' + \
297                      str(problem_number) + '_ns:' + \
298                      int_to_suffix(args.nb_train_samples) + '.param'
299
300     nb_parameters = 0
301     for p in model.parameters(): nb_parameters += p.numel()
302     log_string('nb_parameters {:d}'.format(nb_parameters))
303
304     ##################################################
305     # Tries to load the model
306
307     need_to_train = False
308     try:
309         model.load_state_dict(torch.load(model_filename))
310         log_string('loaded_model ' + model_filename)
311     except:
312         need_to_train = True
313
314     ##################################################
315     # Train if necessary
316
317     if need_to_train:
318
319         log_string('training_model ' + model_filename)
320
321         t = time.time()
322
323         train_set = VignetteSet(problem_number,
324                                 args.nb_train_samples, args.batch_size,
325                                 cuda = torch.cuda.is_available(),
326                                 logger = vignette_logger())
327
328         log_string('data_generation {:0.2f} samples / s'.format(
329             train_set.nb_samples / (time.time() - t))
330         )
331
332         train_model(model, train_set)
333         torch.save(model.state_dict(), model_filename)
334         log_string('saved_model ' + model_filename)
335
336         nb_train_errors = nb_errors(model, train_set)
337
338         log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
339             problem_number,
340             100 * nb_train_errors / train_set.nb_samples,
341             nb_train_errors,
342             train_set.nb_samples)
343         )
344
345     ##################################################
346     # Test if necessary
347
348     if need_to_train or args.test_loaded_models:
349
350         t = time.time()
351
352         test_set = VignetteSet(problem_number,
353                                args.nb_test_samples, args.batch_size,
354                                cuda = torch.cuda.is_available())
355
356         log_string('data_generation {:0.2f} samples / s'.format(
357             test_set.nb_samples / (time.time() - t))
358         )
359
360         nb_test_errors = nb_errors(model, test_set)
361
362         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
363             problem_number,
364             100 * nb_test_errors / test_set.nb_samples,
365             nb_test_errors,
366             test_set.nb_samples)
367         )
368
369 ######################################################################