Added Afroze's DeepNet.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with selector.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27
28 from colorama import Fore, Back, Style
29
30 # Pytorch
31
32 import torch
33
34 from torch import optim
35 from torch import FloatTensor as Tensor
36 from torch.autograd import Variable
37 from torch import nn
38 from torch.nn import functional as fn
39 from torchvision import datasets, transforms, utils
40
41 # SVRT
42
43 from vignette_set import VignetteSet, CompressedVignetteSet
44
45 ######################################################################
46
47 parser = argparse.ArgumentParser(
48     description = 'Simple convnet test on the SVRT.',
49     formatter_class = argparse.ArgumentDefaultsHelpFormatter
50 )
51
52 parser.add_argument('--nb_train_batches',
53                     type = int, default = 1000,
54                     help = 'How many samples for train')
55
56 parser.add_argument('--nb_test_batches',
57                     type = int, default = 100,
58                     help = 'How many samples for test')
59
60 parser.add_argument('--nb_epochs',
61                     type = int, default = 50,
62                     help = 'How many training epochs')
63
64 parser.add_argument('--batch_size',
65                     type = int, default = 100,
66                     help = 'Mini-batch size')
67
68 parser.add_argument('--log_file',
69                     type = str, default = 'default.log',
70                     help = 'Log file name')
71
72 parser.add_argument('--compress_vignettes',
73                     action='store_true', default = False,
74                     help = 'Use lossless compression to reduce the memory footprint')
75
76 parser.add_argument('--deep_model',
77                     action='store_true', default = False,
78                     help = 'Use Afroze\'s Alexnet-like deep model')
79
80 parser.add_argument('--test_loaded_models',
81                     action='store_true', default = False,
82                     help = 'Should we compute the test errors of loaded models')
83
84 args = parser.parse_args()
85
86 ######################################################################
87
88 log_file = open(args.log_file, 'w')
89 pred_log_t = None
90
91 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
92
93 # Log and prints the string, with a time stamp. Does not log the
94 # remark
95 def log_string(s, remark = ''):
96     global pred_log_t
97
98     t = time.time()
99
100     if pred_log_t is None:
101         elapsed = 'start'
102     else:
103         elapsed = '+{:.02f}s'.format(t - pred_log_t)
104
105     pred_log_t = t
106
107     s = Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s
108     log_file.write(s + '\n')
109     log_file.flush()
110     print(s + Fore.CYAN + remark + Style.RESET_ALL)
111
112 ######################################################################
113
114 # Afroze's ShallowNet
115
116 #                       map size   nb. maps
117 #                     ----------------------
118 #    input                128x128    1
119 # -- conv(21x21 x 6)   -> 108x108    6
120 # -- max(2x2)          -> 54x54      6
121 # -- conv(19x19 x 16)  -> 36x36      16
122 # -- max(2x2)          -> 18x18      16
123 # -- conv(18x18 x 120) -> 1x1        120
124 # -- reshape           -> 120        1
125 # -- full(120x84)      -> 84         1
126 # -- full(84x2)        -> 2          1
127
128 class AfrozeShallowNet(nn.Module):
129     def __init__(self):
130         super(AfrozeShallowNet, self).__init__()
131         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
132         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
133         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
134         self.fc1 = nn.Linear(120, 84)
135         self.fc2 = nn.Linear(84, 2)
136         self.name = 'shallownet'
137
138     def forward(self, x):
139         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
140         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
141         x = fn.relu(self.conv3(x))
142         x = x.view(-1, 120)
143         x = fn.relu(self.fc1(x))
144         x = self.fc2(x)
145         return x
146
147 ######################################################################
148
149 # Afroze's DeepNet
150
151 #                       map size   nb. maps
152 #                     ----------------------
153 #    input                128x128    1
154 # -- conv(21x21 x 32 stride=4) -> 28x28    32
155 # -- max(2x2)                  -> 14x14      6
156 # -- conv(7x7 x 96)            -> 8x8      16
157 # -- max(2x2)                  -> 4x4      16
158 # -- conv(5x5 x 96)            -> 26x36      16
159 # -- conv(3x3 x 128)           -> 36x36      16
160 # -- conv(3x3 x 128)           -> 36x36      16
161
162 # -- conv(5x5 x 120) -> 1x1        120
163 # -- reshape           -> 120        1
164 # -- full(3x84)      -> 84         1
165 # -- full(84x2)        -> 2          1
166
167 class AfrozeDeepNet(nn.Module):
168     def __init__(self):
169         super(AfrozeDeepNet, self).__init__()
170         self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
171         self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
172         self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
173         self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
174         self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
175         self.fc1 = nn.Linear(1536, 256)
176         self.fc2 = nn.Linear(256, 256)
177         self.fc3 = nn.Linear(256, 2)
178         self.name = 'deepnet'
179
180     def forward(self, x):
181         x = self.conv1(x)
182         x = fn.max_pool2d(x, kernel_size=2)
183         x = fn.relu(x)
184
185         x = self.conv2(x)
186         x = fn.max_pool2d(x, kernel_size=2)
187         x = fn.relu(x)
188
189         x = self.conv3(x)
190         x = fn.relu(x)
191
192         x = self.conv4(x)
193         x = fn.relu(x)
194
195         x = self.conv5(x)
196         x = fn.max_pool2d(x, kernel_size=2)
197         x = fn.relu(x)
198
199         x = x.view(-1, 1536)
200
201         x = self.fc1(x)
202         x = fn.relu(x)
203
204         x = self.fc2(x)
205         x = fn.relu(x)
206
207         x = self.fc3(x)
208
209         return x
210
211 ######################################################################
212
213 def train_model(model, train_set):
214     batch_size = args.batch_size
215     criterion = nn.CrossEntropyLoss()
216
217     if torch.cuda.is_available():
218         criterion.cuda()
219
220     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
221
222     start_t = time.time()
223
224     for e in range(0, args.nb_epochs):
225         acc_loss = 0.0
226         for b in range(0, train_set.nb_batches):
227             input, target = train_set.get_batch(b)
228             output = model.forward(Variable(input))
229             loss = criterion(output, Variable(target))
230             acc_loss = acc_loss + loss.data[0]
231             model.zero_grad()
232             loss.backward()
233             optimizer.step()
234         dt = (time.time() - start_t) / (e + 1)
235         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
236                    ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
237
238     return model
239
240 ######################################################################
241
242 def nb_errors(model, data_set):
243     ne = 0
244     for b in range(0, data_set.nb_batches):
245         input, target = data_set.get_batch(b)
246         output = model.forward(Variable(input))
247         wta_prediction = output.data.max(1)[1].view(-1)
248
249         for i in range(0, data_set.batch_size):
250             if wta_prediction[i] != target[i]:
251                 ne = ne + 1
252
253     return ne
254
255 ######################################################################
256
257 for arg in vars(args):
258     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
259
260 ######################################################################
261
262 for problem_number in range(1, 24):
263
264     log_string('**** problem ' + str(problem_number) + ' ****')
265
266     if args.deep_model:
267         model = AfrozeDeepNet()
268     else:
269         model = AfrozeShallowNet()
270
271     if torch.cuda.is_available():
272         model.cuda()
273
274     model_filename = model.name + '_' + \
275                      str(problem_number) + '_' + \
276                      str(args.nb_train_batches) + '.param'
277
278     nb_parameters = 0
279     for p in model.parameters(): nb_parameters += p.numel()
280     log_string('nb_parameters {:d}'.format(nb_parameters))
281
282     need_to_train = False
283     try:
284         model.load_state_dict(torch.load(model_filename))
285         log_string('loaded_model ' + model_filename)
286     except:
287         need_to_train = True
288
289     if need_to_train:
290
291         log_string('training_model ' + model_filename)
292
293         t = time.time()
294
295         if args.compress_vignettes:
296             train_set = CompressedVignetteSet(problem_number,
297                                               args.nb_train_batches, args.batch_size,
298                                               cuda=torch.cuda.is_available())
299         else:
300             train_set = VignetteSet(problem_number,
301                                     args.nb_train_batches, args.batch_size,
302                                     cuda=torch.cuda.is_available())
303
304         log_string('data_generation {:0.2f} samples / s'.format(train_set.nb_samples / (time.time() - t)))
305
306         train_model(model, train_set)
307         torch.save(model.state_dict(), model_filename)
308         log_string('saved_model ' + model_filename)
309
310         nb_train_errors = nb_errors(model, train_set)
311
312         log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
313             problem_number,
314             100 * nb_train_errors / train_set.nb_samples,
315             nb_train_errors,
316             train_set.nb_samples)
317         )
318
319     if need_to_train or args.test_loaded_models:
320
321         t = time.time()
322
323         if args.compress_vignettes:
324             test_set = CompressedVignetteSet(problem_number,
325                                              args.nb_test_batches, args.batch_size,
326                                              cuda=torch.cuda.is_available())
327         else:
328             test_set = VignetteSet(problem_number,
329                                    args.nb_test_batches, args.batch_size,
330                                    cuda=torch.cuda.is_available())
331
332         log_string('data_generation {:0.2f} samples / s'.format(test_set.nb_samples / (time.time() - t)))
333
334         nb_test_errors = nb_errors(model, test_set)
335
336         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
337             problem_number,
338             100 * nb_test_errors / test_set.nb_samples,
339             nb_test_errors,
340             test_set.nb_samples)
341         )
342
343 ######################################################################