Compute the test error when the network is loaded and not trained.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with selector.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27
28 from colorama import Fore, Back, Style
29
30 # Pytorch
31
32 import torch
33
34 from torch import optim
35 from torch import FloatTensor as Tensor
36 from torch.autograd import Variable
37 from torch import nn
38 from torch.nn import functional as fn
39 from torchvision import datasets, transforms, utils
40
41 # SVRT
42
43 from vignette_set import VignetteSet, CompressedVignetteSet
44
45 ######################################################################
46
47 parser = argparse.ArgumentParser(
48     description = 'Simple convnet test on the SVRT.',
49     formatter_class = argparse.ArgumentDefaultsHelpFormatter
50 )
51
52 parser.add_argument('--nb_train_batches',
53                     type = int, default = 1000,
54                     help = 'How many samples for train')
55
56 parser.add_argument('--nb_test_batches',
57                     type = int, default = 100,
58                     help = 'How many samples for test')
59
60 parser.add_argument('--nb_epochs',
61                     type = int, default = 50,
62                     help = 'How many training epochs')
63
64 parser.add_argument('--batch_size',
65                     type = int, default = 100,
66                     help = 'Mini-batch size')
67
68 parser.add_argument('--log_file',
69                     type = str, default = 'cnn-svrt.log',
70                     help = 'Log file name')
71
72 parser.add_argument('--compress_vignettes',
73                     action='store_true', default = False,
74                     help = 'Use lossless compression to reduce the memory footprint')
75
76 args = parser.parse_args()
77
78 ######################################################################
79
80 log_file = open(args.log_file, 'w')
81 pred_log_t = None
82
83 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
84
85 def log_string(s):
86     global pred_log_t
87     t = time.time()
88
89     if pred_log_t is None:
90         elapsed = 'start'
91     else:
92         elapsed = '+{:.02f}s'.format(t - pred_log_t)
93     pred_log_t = t
94     s = Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s
95     log_file.write(s + '\n')
96     log_file.flush()
97     print(s)
98
99 ######################################################################
100
101 # Afroze's ShallowNet
102
103 #                       map size   nb. maps
104 #                     ----------------------
105 #    input                128x128    1
106 # -- conv(21x21 x 6)   -> 108x108    6
107 # -- max(2x2)          -> 54x54      6
108 # -- conv(19x19 x 16)  -> 36x36      16
109 # -- max(2x2)          -> 18x18      16
110 # -- conv(18x18 x 120) -> 1x1        120
111 # -- reshape           -> 120        1
112 # -- full(120x84)      -> 84         1
113 # -- full(84x2)        -> 2          1
114
115 class AfrozeShallowNet(nn.Module):
116     def __init__(self):
117         super(AfrozeShallowNet, self).__init__()
118         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
119         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
120         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
121         self.fc1 = nn.Linear(120, 84)
122         self.fc2 = nn.Linear(84, 2)
123         self.name = 'shallownet'
124
125     def forward(self, x):
126         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
127         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
128         x = fn.relu(self.conv3(x))
129         x = x.view(-1, 120)
130         x = fn.relu(self.fc1(x))
131         x = self.fc2(x)
132         return x
133
134 ######################################################################
135
136 def train_model(model, train_set):
137     batch_size = args.batch_size
138     criterion = nn.CrossEntropyLoss()
139
140     if torch.cuda.is_available():
141         criterion.cuda()
142
143     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
144
145     for e in range(0, args.nb_epochs):
146         acc_loss = 0.0
147         for b in range(0, train_set.nb_batches):
148             input, target = train_set.get_batch(b)
149             output = model.forward(Variable(input))
150             loss = criterion(output, Variable(target))
151             acc_loss = acc_loss + loss.data[0]
152             model.zero_grad()
153             loss.backward()
154             optimizer.step()
155         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss))
156
157     return model
158
159 ######################################################################
160
161 def nb_errors(model, data_set):
162     ne = 0
163     for b in range(0, data_set.nb_batches):
164         input, target = data_set.get_batch(b)
165         output = model.forward(Variable(input))
166         wta_prediction = output.data.max(1)[1].view(-1)
167
168         for i in range(0, data_set.batch_size):
169             if wta_prediction[i] != target[i]:
170                 ne = ne + 1
171
172     return ne
173
174 ######################################################################
175
176 for arg in vars(args):
177     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
178
179 ######################################################################
180
181 for problem_number in range(1, 24):
182
183     log_string('**** problem ' + str(problem_number) + ' ****')
184
185     model = AfrozeShallowNet()
186
187     if torch.cuda.is_available():
188         model.cuda()
189
190     model_filename = model.name + '_' + \
191                      str(problem_number) + '_' + \
192                      str(args.nb_train_batches) + '.param'
193
194     nb_parameters = 0
195     for p in model.parameters(): nb_parameters += p.numel()
196     log_string('nb_parameters {:d}'.format(nb_parameters))
197
198     need_to_train = False
199     try:
200         model.load_state_dict(torch.load(model_filename))
201         log_string('loaded_model ' + model_filename)
202     except:
203         need_to_train = True
204
205     if need_to_train:
206
207         log_string('training_model ' + model_filename)
208
209         t = time.time()
210
211         if args.compress_vignettes:
212             train_set = CompressedVignetteSet(problem_number,
213                                               args.nb_train_batches, args.batch_size,
214                                               cuda=torch.cuda.is_available())
215         else:
216             train_set = VignetteSet(problem_number,
217                                     args.nb_train_batches, args.batch_size,
218                                     cuda=torch.cuda.is_available())
219
220         log_string('data_generation {:0.2f} samples / s'.format(
221             (train_set.nb_samples + test_set.nb_samples) / (time.time() - t))
222         )
223
224         train_model(model, train_set)
225         torch.save(model.state_dict(), model_filename)
226         log_string('saved_model ' + model_filename)
227
228         nb_train_errors = nb_errors(model, train_set)
229
230         log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
231             problem_number,
232             100 * nb_train_errors / train_set.nb_samples,
233             nb_train_errors,
234             train_set.nb_samples)
235         )
236
237     if args.compress_vignettes:
238         test_set = CompressedVignetteSet(problem_number,
239                                          args.nb_test_batches, args.batch_size,
240                                          cuda=torch.cuda.is_available())
241     else:
242         test_set = VignetteSet(problem_number,
243                                args.nb_test_batches, args.batch_size,
244                                cuda=torch.cuda.is_available())
245
246     nb_test_errors = nb_errors(model, test_set)
247
248     log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
249         problem_number,
250         100 * nb_test_errors / test_set.nb_samples,
251         nb_test_errors,
252         test_set.nb_samples)
253     )
254
255
256 ######################################################################