59133455353fc3c27c97466fc5595c0b8b9abbb0
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with selector.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27
28 from colorama import Fore, Back, Style
29
30 # Pytorch
31
32 import torch
33
34 from torch import optim
35 from torch import FloatTensor as Tensor
36 from torch.autograd import Variable
37 from torch import nn
38 from torch.nn import functional as fn
39 from torchvision import datasets, transforms, utils
40
41 # SVRT
42
43 from vignette_set import VignetteSet, CompressedVignetteSet
44
45 ######################################################################
46
47 parser = argparse.ArgumentParser(
48     description = 'Simple convnet test on the SVRT.',
49     formatter_class = argparse.ArgumentDefaultsHelpFormatter
50 )
51
52 parser.add_argument('--nb_train_batches',
53                     type = int, default = 1000,
54                     help = 'How many samples for train')
55
56 parser.add_argument('--nb_test_batches',
57                     type = int, default = 100,
58                     help = 'How many samples for test')
59
60 parser.add_argument('--nb_epochs',
61                     type = int, default = 50,
62                     help = 'How many training epochs')
63
64 parser.add_argument('--batch_size',
65                     type = int, default = 100,
66                     help = 'Mini-batch size')
67
68 parser.add_argument('--log_file',
69                     type = str, default = 'cnn-svrt.log',
70                     help = 'Log file name')
71
72 parser.add_argument('--compress_vignettes',
73                     action='store_true', default = False,
74                     help = 'Use lossless compression to reduce the memory footprint')
75
76 parser.add_argument('--test_loaded_models',
77                     action='store_true', default = False,
78                     help = 'Should we compute the test error of models we load')
79
80 args = parser.parse_args()
81
82 ######################################################################
83
84 log_file = open(args.log_file, 'w')
85 pred_log_t = None
86
87 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
88
89 def log_string(s):
90     global pred_log_t
91     t = time.time()
92
93     if pred_log_t is None:
94         elapsed = 'start'
95     else:
96         elapsed = '+{:.02f}s'.format(t - pred_log_t)
97     pred_log_t = t
98     s = Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s
99     log_file.write(s + '\n')
100     log_file.flush()
101     print(s)
102
103 ######################################################################
104
105 # Afroze's ShallowNet
106
107 #                       map size   nb. maps
108 #                     ----------------------
109 #    input                128x128    1
110 # -- conv(21x21 x 6)   -> 108x108    6
111 # -- max(2x2)          -> 54x54      6
112 # -- conv(19x19 x 16)  -> 36x36      16
113 # -- max(2x2)          -> 18x18      16
114 # -- conv(18x18 x 120) -> 1x1        120
115 # -- reshape           -> 120        1
116 # -- full(120x84)      -> 84         1
117 # -- full(84x2)        -> 2          1
118
119 class AfrozeShallowNet(nn.Module):
120     def __init__(self):
121         super(AfrozeShallowNet, self).__init__()
122         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
123         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
124         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
125         self.fc1 = nn.Linear(120, 84)
126         self.fc2 = nn.Linear(84, 2)
127         self.name = 'shallownet'
128
129     def forward(self, x):
130         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
131         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
132         x = fn.relu(self.conv3(x))
133         x = x.view(-1, 120)
134         x = fn.relu(self.fc1(x))
135         x = self.fc2(x)
136         return x
137
138 ######################################################################
139
140 def train_model(model, train_set):
141     batch_size = args.batch_size
142     criterion = nn.CrossEntropyLoss()
143
144     if torch.cuda.is_available():
145         criterion.cuda()
146
147     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
148
149     for e in range(0, args.nb_epochs):
150         acc_loss = 0.0
151         for b in range(0, train_set.nb_batches):
152             input, target = train_set.get_batch(b)
153             output = model.forward(Variable(input))
154             loss = criterion(output, Variable(target))
155             acc_loss = acc_loss + loss.data[0]
156             model.zero_grad()
157             loss.backward()
158             optimizer.step()
159         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss))
160
161     return model
162
163 ######################################################################
164
165 def nb_errors(model, data_set):
166     ne = 0
167     for b in range(0, data_set.nb_batches):
168         input, target = data_set.get_batch(b)
169         output = model.forward(Variable(input))
170         wta_prediction = output.data.max(1)[1].view(-1)
171
172         for i in range(0, data_set.batch_size):
173             if wta_prediction[i] != target[i]:
174                 ne = ne + 1
175
176     return ne
177
178 ######################################################################
179
180 for arg in vars(args):
181     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
182
183 ######################################################################
184
185 for problem_number in range(1, 24):
186
187     log_string('**** problem ' + str(problem_number) + ' ****')
188
189     model = AfrozeShallowNet()
190
191     if torch.cuda.is_available():
192         model.cuda()
193
194     model_filename = model.name + '_' + \
195                      str(problem_number) + '_' + \
196                      str(args.nb_train_batches) + '.param'
197
198     nb_parameters = 0
199     for p in model.parameters(): nb_parameters += p.numel()
200     log_string('nb_parameters {:d}'.format(nb_parameters))
201
202     need_to_train = False
203     try:
204         model.load_state_dict(torch.load(model_filename))
205         log_string('loaded_model ' + model_filename)
206     except:
207         need_to_train = True
208
209     if need_to_train:
210
211         log_string('training_model ' + model_filename)
212
213         t = time.time()
214
215         if args.compress_vignettes:
216             train_set = CompressedVignetteSet(problem_number,
217                                               args.nb_train_batches, args.batch_size,
218                                               cuda=torch.cuda.is_available())
219         else:
220             train_set = VignetteSet(problem_number,
221                                     args.nb_train_batches, args.batch_size,
222                                     cuda=torch.cuda.is_available())
223
224         log_string('data_generation {:0.2f} samples / s'.format(train_set.nb_samples / (time.time() - t)))
225
226         train_model(model, train_set)
227         torch.save(model.state_dict(), model_filename)
228         log_string('saved_model ' + model_filename)
229
230         nb_train_errors = nb_errors(model, train_set)
231
232         log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
233             problem_number,
234             100 * nb_train_errors / train_set.nb_samples,
235             nb_train_errors,
236             train_set.nb_samples)
237         )
238
239     if need_to_train or args.test_loaded_models:
240
241         t = time.time()
242
243         if args.compress_vignettes:
244             test_set = CompressedVignetteSet(problem_number,
245                                              args.nb_test_batches, args.batch_size,
246                                              cuda=torch.cuda.is_available())
247         else:
248             test_set = VignetteSet(problem_number,
249                                    args.nb_test_batches, args.batch_size,
250                                    cuda=torch.cuda.is_available())
251
252         log_string('data_generation {:0.2f} samples / s'.format(test_set.nb_samples / (time.time() - t)))
253
254         nb_test_errors = nb_errors(model, test_set)
255
256         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
257             problem_number,
258             100 * nb_test_errors / test_set.nb_samples,
259             nb_test_errors,
260             test_set.nb_samples)
261         )
262
263 ######################################################################