Heavy fix.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with selector.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27
28 from colorama import Fore, Back, Style
29
30 # Pytorch
31
32 import torch
33
34 from torch import optim
35 from torch import FloatTensor as Tensor
36 from torch.autograd import Variable
37 from torch import nn
38 from torch.nn import functional as fn
39 from torchvision import datasets, transforms, utils
40
41 # SVRT
42
43 from vignette_set import VignetteSet, CompressedVignetteSet
44
45 ######################################################################
46
47 parser = argparse.ArgumentParser(
48     description = 'Simple convnet test on the SVRT.',
49     formatter_class = argparse.ArgumentDefaultsHelpFormatter
50 )
51
52 parser.add_argument('--nb_train_batches',
53                     type = int, default = 1000,
54                     help = 'How many samples for train')
55
56 parser.add_argument('--nb_test_batches',
57                     type = int, default = 100,
58                     help = 'How many samples for test')
59
60 parser.add_argument('--nb_epochs',
61                     type = int, default = 50,
62                     help = 'How many training epochs')
63
64 parser.add_argument('--batch_size',
65                     type = int, default = 100,
66                     help = 'Mini-batch size')
67
68 parser.add_argument('--log_file',
69                     type = str, default = 'cnn-svrt.log',
70                     help = 'Log file name')
71
72 parser.add_argument('--compress_vignettes',
73                     action='store_true', default = False,
74                     help = 'Use lossless compression to reduce the memory footprint')
75
76 args = parser.parse_args()
77
78 ######################################################################
79
80 log_file = open(args.log_file, 'w')
81
82 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
83
84 def log_string(s):
85     s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + s
86     log_file.write(s + '\n')
87     log_file.flush()
88     print(s)
89
90 ######################################################################
91
92 # Afroze's ShallowNet
93
94 #                       map size   nb. maps
95 #                     ----------------------
96 #    input                128x128    1
97 # -- conv(21x21 x 6)   -> 108x108    6
98 # -- max(2x2)          -> 54x54      6
99 # -- conv(19x19 x 16)  -> 36x36      16
100 # -- max(2x2)          -> 18x18      16
101 # -- conv(18x18 x 120) -> 1x1        120
102 # -- reshape           -> 120        1
103 # -- full(120x84)      -> 84         1
104 # -- full(84x2)        -> 2          1
105
106 class AfrozeShallowNet(nn.Module):
107     def __init__(self):
108         super(AfrozeShallowNet, self).__init__()
109         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
110         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
111         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
112         self.fc1 = nn.Linear(120, 84)
113         self.fc2 = nn.Linear(84, 2)
114         self.name = 'shallownet'
115
116     def forward(self, x):
117         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
118         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
119         x = fn.relu(self.conv3(x))
120         x = x.view(-1, 120)
121         x = fn.relu(self.fc1(x))
122         x = self.fc2(x)
123         return x
124
125 ######################################################################
126
127 def train_model(model, train_set):
128     batch_size = args.batch_size
129     criterion = nn.CrossEntropyLoss()
130
131     if torch.cuda.is_available():
132         criterion.cuda()
133
134     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
135
136     for e in range(0, args.nb_epochs):
137         acc_loss = 0.0
138         for b in range(0, train_set.nb_batches):
139             input, target = train_set.get_batch(b)
140             output = model.forward(Variable(input))
141             loss = criterion(output, Variable(target))
142             acc_loss = acc_loss + loss.data[0]
143             model.zero_grad()
144             loss.backward()
145             optimizer.step()
146         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss))
147
148     return model
149
150 ######################################################################
151
152 def nb_errors(model, data_set):
153     ne = 0
154     for b in range(0, data_set.nb_batches):
155         input, target = data_set.get_batch(b)
156         output = model.forward(Variable(input))
157         wta_prediction = output.data.max(1)[1].view(-1)
158
159         for i in range(0, data_set.batch_size):
160             if wta_prediction[i] != target[i]:
161                 ne = ne + 1
162
163     return ne
164
165 ######################################################################
166
167 for arg in vars(args):
168     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
169
170 ######################################################################
171
172 for problem_number in range(1, 24):
173
174     model = AfrozeShallowNet()
175
176     if torch.cuda.is_available():
177         model.cuda()
178
179     model_filename = model.name + '_' + \
180                      str(problem_number) + '_' + \
181                      str(args.nb_train_batches) + '.param'
182
183     nb_parameters = 0
184     for p in model.parameters(): nb_parameters += p.numel()
185     log_string('nb_parameters {:d}'.format(nb_parameters))
186
187     need_to_train = False
188     try:
189         model.load_state_dict(torch.load(model_filename))
190         log_string('loaded_model ' + model_filename)
191     except:
192         need_to_train = True
193
194     if need_to_train:
195
196         log_string('training_model ' + model_filename)
197
198         t = time.time()
199
200         if args.compress_vignettes:
201             train_set = CompressedVignetteSet(problem_number,
202                                               args.nb_train_batches, args.batch_size,
203                                               cuda=torch.cuda.is_available())
204             test_set = CompressedVignetteSet(problem_number,
205                                              args.nb_test_batches, args.batch_size,
206                                              cuda=torch.cuda.is_available())
207         else:
208             train_set = VignetteSet(problem_number,
209                                     args.nb_train_batches, args.batch_size,
210                                     cuda=torch.cuda.is_available())
211             test_set = VignetteSet(problem_number,
212                                    args.nb_test_batches, args.batch_size,
213                                    cuda=torch.cuda.is_available())
214
215         log_string('data_generation {:0.2f} samples / s'.format(
216             (train_set.nb_samples + test_set.nb_samples) / (time.time() - t))
217         )
218
219         train_model(model, train_set)
220         torch.save(model.state_dict(), model_filename)
221         log_string('saved_model ' + model_filename)
222
223         nb_train_errors = nb_errors(model, train_set)
224
225         log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
226             problem_number,
227             100 * nb_train_errors / train_set.nb_samples,
228             nb_train_errors,
229             train_set.nb_samples)
230         )
231
232         nb_test_errors = nb_errors(model, test_set)
233
234         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
235             problem_number,
236             100 * nb_test_errors / test_set.nb_samples,
237             nb_test_errors,
238             test_set.nb_samples)
239         )
240
241 ######################################################################