Minor ETA fix.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with selector.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27
28 from colorama import Fore, Back, Style
29
30 # Pytorch
31
32 import torch
33
34 from torch import optim
35 from torch import FloatTensor as Tensor
36 from torch.autograd import Variable
37 from torch import nn
38 from torch.nn import functional as fn
39 from torchvision import datasets, transforms, utils
40
41 # SVRT
42
43 from vignette_set import VignetteSet, CompressedVignetteSet
44
45 ######################################################################
46
47 parser = argparse.ArgumentParser(
48     description = 'Simple convnet test on the SVRT.',
49     formatter_class = argparse.ArgumentDefaultsHelpFormatter
50 )
51
52 parser.add_argument('--nb_train_batches',
53                     type = int, default = 1000,
54                     help = 'How many samples for train')
55
56 parser.add_argument('--nb_test_batches',
57                     type = int, default = 100,
58                     help = 'How many samples for test')
59
60 parser.add_argument('--nb_epochs',
61                     type = int, default = 50,
62                     help = 'How many training epochs')
63
64 parser.add_argument('--batch_size',
65                     type = int, default = 100,
66                     help = 'Mini-batch size')
67
68 parser.add_argument('--log_file',
69                     type = str, default = 'cnn-svrt.log',
70                     help = 'Log file name')
71
72 parser.add_argument('--compress_vignettes',
73                     action='store_true', default = False,
74                     help = 'Use lossless compression to reduce the memory footprint')
75
76 parser.add_argument('--test_loaded_models',
77                     action='store_true', default = False,
78                     help = 'Should we compute the test error of models we load')
79
80 args = parser.parse_args()
81
82 ######################################################################
83
84 log_file = open(args.log_file, 'w')
85 pred_log_t = None
86
87 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
88
89 def log_string(s):
90     global pred_log_t
91     t = time.time()
92
93     if pred_log_t is None:
94         elapsed = 'start'
95     else:
96         elapsed = '+{:.02f}s'.format(t - pred_log_t)
97     pred_log_t = t
98     s = Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s
99     log_file.write(s + '\n')
100     log_file.flush()
101     print(s)
102
103 ######################################################################
104
105 # Afroze's ShallowNet
106
107 #                       map size   nb. maps
108 #                     ----------------------
109 #    input                128x128    1
110 # -- conv(21x21 x 6)   -> 108x108    6
111 # -- max(2x2)          -> 54x54      6
112 # -- conv(19x19 x 16)  -> 36x36      16
113 # -- max(2x2)          -> 18x18      16
114 # -- conv(18x18 x 120) -> 1x1        120
115 # -- reshape           -> 120        1
116 # -- full(120x84)      -> 84         1
117 # -- full(84x2)        -> 2          1
118
119 class AfrozeShallowNet(nn.Module):
120     def __init__(self):
121         super(AfrozeShallowNet, self).__init__()
122         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
123         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
124         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
125         self.fc1 = nn.Linear(120, 84)
126         self.fc2 = nn.Linear(84, 2)
127         self.name = 'shallownet'
128
129     def forward(self, x):
130         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
131         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
132         x = fn.relu(self.conv3(x))
133         x = x.view(-1, 120)
134         x = fn.relu(self.fc1(x))
135         x = self.fc2(x)
136         return x
137
138 ######################################################################
139
140 def train_model(model, train_set):
141     batch_size = args.batch_size
142     criterion = nn.CrossEntropyLoss()
143
144     if torch.cuda.is_available():
145         criterion.cuda()
146
147     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
148
149     start_t = time.time()
150
151     for e in range(0, args.nb_epochs):
152         acc_loss = 0.0
153         for b in range(0, train_set.nb_batches):
154             input, target = train_set.get_batch(b)
155             output = model.forward(Variable(input))
156             loss = criterion(output, Variable(target))
157             acc_loss = acc_loss + loss.data[0]
158             model.zero_grad()
159             loss.backward()
160             optimizer.step()
161         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss))
162         dt = (time.time() - start_t) / (e + 1)
163         print(Fore.CYAN + 'ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + Style.RESET_ALL)
164
165     return model
166
167 ######################################################################
168
169 def nb_errors(model, data_set):
170     ne = 0
171     for b in range(0, data_set.nb_batches):
172         input, target = data_set.get_batch(b)
173         output = model.forward(Variable(input))
174         wta_prediction = output.data.max(1)[1].view(-1)
175
176         for i in range(0, data_set.batch_size):
177             if wta_prediction[i] != target[i]:
178                 ne = ne + 1
179
180     return ne
181
182 ######################################################################
183
184 for arg in vars(args):
185     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
186
187 ######################################################################
188
189 for problem_number in range(1, 24):
190
191     log_string('**** problem ' + str(problem_number) + ' ****')
192
193     model = AfrozeShallowNet()
194
195     if torch.cuda.is_available():
196         model.cuda()
197
198     model_filename = model.name + '_' + \
199                      str(problem_number) + '_' + \
200                      str(args.nb_train_batches) + '.param'
201
202     nb_parameters = 0
203     for p in model.parameters(): nb_parameters += p.numel()
204     log_string('nb_parameters {:d}'.format(nb_parameters))
205
206     need_to_train = False
207     try:
208         model.load_state_dict(torch.load(model_filename))
209         log_string('loaded_model ' + model_filename)
210     except:
211         need_to_train = True
212
213     if need_to_train:
214
215         log_string('training_model ' + model_filename)
216
217         t = time.time()
218
219         if args.compress_vignettes:
220             train_set = CompressedVignetteSet(problem_number,
221                                               args.nb_train_batches, args.batch_size,
222                                               cuda=torch.cuda.is_available())
223         else:
224             train_set = VignetteSet(problem_number,
225                                     args.nb_train_batches, args.batch_size,
226                                     cuda=torch.cuda.is_available())
227
228         log_string('data_generation {:0.2f} samples / s'.format(train_set.nb_samples / (time.time() - t)))
229
230         train_model(model, train_set)
231         torch.save(model.state_dict(), model_filename)
232         log_string('saved_model ' + model_filename)
233
234         nb_train_errors = nb_errors(model, train_set)
235
236         log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
237             problem_number,
238             100 * nb_train_errors / train_set.nb_samples,
239             nb_train_errors,
240             train_set.nb_samples)
241         )
242
243     if need_to_train or args.test_loaded_models:
244
245         t = time.time()
246
247         if args.compress_vignettes:
248             test_set = CompressedVignetteSet(problem_number,
249                                              args.nb_test_batches, args.batch_size,
250                                              cuda=torch.cuda.is_available())
251         else:
252             test_set = VignetteSet(problem_number,
253                                    args.nb_test_batches, args.batch_size,
254                                    cuda=torch.cuda.is_available())
255
256         log_string('data_generation {:0.2f} samples / s'.format(test_set.nb_samples / (time.time() - t)))
257
258         nb_test_errors = nb_errors(model, test_set)
259
260         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
261             problem_number,
262             100 * nb_test_errors / test_set.nb_samples,
263             nb_test_errors,
264             test_set.nb_samples)
265         )
266
267 ######################################################################