Cosmetics.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with selector.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27
28 from colorama import Fore, Back, Style
29
30 # Pytorch
31
32 import torch
33
34 from torch import optim
35 from torch import FloatTensor as Tensor
36 from torch.autograd import Variable
37 from torch import nn
38 from torch.nn import functional as fn
39 from torchvision import datasets, transforms, utils
40
41 # SVRT
42
43 from vignette_set import VignetteSet, CompressedVignetteSet
44
45 ######################################################################
46
47 parser = argparse.ArgumentParser(
48     description = 'Simple convnet test on the SVRT.',
49     formatter_class = argparse.ArgumentDefaultsHelpFormatter
50 )
51
52 parser.add_argument('--nb_train_batches',
53                     type = int, default = 1000,
54                     help = 'How many samples for train')
55
56 parser.add_argument('--nb_test_batches',
57                     type = int, default = 100,
58                     help = 'How many samples for test')
59
60 parser.add_argument('--nb_epochs',
61                     type = int, default = 50,
62                     help = 'How many training epochs')
63
64 parser.add_argument('--batch_size',
65                     type = int, default = 100,
66                     help = 'Mini-batch size')
67
68 parser.add_argument('--log_file',
69                     type = str, default = 'default.log',
70                     help = 'Log file name')
71
72 parser.add_argument('--compress_vignettes',
73                     action='store_true', default = False,
74                     help = 'Use lossless compression to reduce the memory footprint')
75
76 parser.add_argument('--test_loaded_models',
77                     action='store_true', default = False,
78                     help = 'Should we compute the test errors of loaded models')
79
80 args = parser.parse_args()
81
82 ######################################################################
83
84 log_file = open(args.log_file, 'w')
85 pred_log_t = None
86
87 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
88
89 def log_string(s):
90     global pred_log_t
91
92     t = time.time()
93
94     if pred_log_t is None:
95         elapsed = 'start'
96     else:
97         elapsed = '+{:.02f}s'.format(t - pred_log_t)
98
99     pred_log_t = t
100
101     s = Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s
102     log_file.write(s + '\n')
103     log_file.flush()
104     print(s)
105
106 ######################################################################
107
108 # Afroze's ShallowNet
109
110 #                       map size   nb. maps
111 #                     ----------------------
112 #    input                128x128    1
113 # -- conv(21x21 x 6)   -> 108x108    6
114 # -- max(2x2)          -> 54x54      6
115 # -- conv(19x19 x 16)  -> 36x36      16
116 # -- max(2x2)          -> 18x18      16
117 # -- conv(18x18 x 120) -> 1x1        120
118 # -- reshape           -> 120        1
119 # -- full(120x84)      -> 84         1
120 # -- full(84x2)        -> 2          1
121
122 class AfrozeShallowNet(nn.Module):
123     def __init__(self):
124         super(AfrozeShallowNet, self).__init__()
125         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
126         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
127         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
128         self.fc1 = nn.Linear(120, 84)
129         self.fc2 = nn.Linear(84, 2)
130         self.name = 'shallownet'
131
132     def forward(self, x):
133         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
134         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
135         x = fn.relu(self.conv3(x))
136         x = x.view(-1, 120)
137         x = fn.relu(self.fc1(x))
138         x = self.fc2(x)
139         return x
140
141 ######################################################################
142
143 def train_model(model, train_set):
144     batch_size = args.batch_size
145     criterion = nn.CrossEntropyLoss()
146
147     if torch.cuda.is_available():
148         criterion.cuda()
149
150     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
151
152     start_t = time.time()
153
154     for e in range(0, args.nb_epochs):
155         acc_loss = 0.0
156         for b in range(0, train_set.nb_batches):
157             input, target = train_set.get_batch(b)
158             output = model.forward(Variable(input))
159             loss = criterion(output, Variable(target))
160             acc_loss = acc_loss + loss.data[0]
161             model.zero_grad()
162             loss.backward()
163             optimizer.step()
164         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss))
165         dt = (time.time() - start_t) / (e + 1)
166         print(Fore.CYAN + 'ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + Style.RESET_ALL)
167
168     return model
169
170 ######################################################################
171
172 def nb_errors(model, data_set):
173     ne = 0
174     for b in range(0, data_set.nb_batches):
175         input, target = data_set.get_batch(b)
176         output = model.forward(Variable(input))
177         wta_prediction = output.data.max(1)[1].view(-1)
178
179         for i in range(0, data_set.batch_size):
180             if wta_prediction[i] != target[i]:
181                 ne = ne + 1
182
183     return ne
184
185 ######################################################################
186
187 for arg in vars(args):
188     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
189
190 ######################################################################
191
192 for problem_number in range(1, 24):
193
194     log_string('**** problem ' + str(problem_number) + ' ****')
195
196     model = AfrozeShallowNet()
197
198     if torch.cuda.is_available():
199         model.cuda()
200
201     model_filename = model.name + '_' + \
202                      str(problem_number) + '_' + \
203                      str(args.nb_train_batches) + '.param'
204
205     nb_parameters = 0
206     for p in model.parameters(): nb_parameters += p.numel()
207     log_string('nb_parameters {:d}'.format(nb_parameters))
208
209     need_to_train = False
210     try:
211         model.load_state_dict(torch.load(model_filename))
212         log_string('loaded_model ' + model_filename)
213     except:
214         need_to_train = True
215
216     if need_to_train:
217
218         log_string('training_model ' + model_filename)
219
220         t = time.time()
221
222         if args.compress_vignettes:
223             train_set = CompressedVignetteSet(problem_number,
224                                               args.nb_train_batches, args.batch_size,
225                                               cuda=torch.cuda.is_available())
226         else:
227             train_set = VignetteSet(problem_number,
228                                     args.nb_train_batches, args.batch_size,
229                                     cuda=torch.cuda.is_available())
230
231         log_string('data_generation {:0.2f} samples / s'.format(train_set.nb_samples / (time.time() - t)))
232
233         train_model(model, train_set)
234         torch.save(model.state_dict(), model_filename)
235         log_string('saved_model ' + model_filename)
236
237         nb_train_errors = nb_errors(model, train_set)
238
239         log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
240             problem_number,
241             100 * nb_train_errors / train_set.nb_samples,
242             nb_train_errors,
243             train_set.nb_samples)
244         )
245
246     if need_to_train or args.test_loaded_models:
247
248         t = time.time()
249
250         if args.compress_vignettes:
251             test_set = CompressedVignetteSet(problem_number,
252                                              args.nb_test_batches, args.batch_size,
253                                              cuda=torch.cuda.is_available())
254         else:
255             test_set = VignetteSet(problem_number,
256                                    args.nb_test_batches, args.batch_size,
257                                    cuda=torch.cuda.is_available())
258
259         log_string('data_generation {:0.2f} samples / s'.format(test_set.nb_samples / (time.time() - t)))
260
261         nb_test_errors = nb_errors(model, test_set)
262
263         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
264             problem_number,
265             100 * nb_test_errors / test_set.nb_samples,
266             nb_test_errors,
267             test_set.nb_samples)
268         )
269
270 ######################################################################