bbce4c92e6426d48d15676372822ff86962066ed
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with selector.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27
28 from colorama import Fore, Back, Style
29
30 import torch
31
32 from torch import optim
33 from torch import FloatTensor as Tensor
34 from torch.autograd import Variable
35 from torch import nn
36 from torch.nn import functional as fn
37 from torchvision import datasets, transforms, utils
38
39 import svrt
40
41 ######################################################################
42
43 parser = argparse.ArgumentParser(
44     description = 'Simple convnet test on the SVRT.',
45     formatter_class = argparse.ArgumentDefaultsHelpFormatter
46 )
47
48 parser.add_argument('--nb_train_batches',
49                     type = int, default = 1000,
50                     help = 'How many samples for train')
51
52 parser.add_argument('--nb_test_batches',
53                     type = int, default = 100,
54                     help = 'How many samples for test')
55
56 parser.add_argument('--nb_epochs',
57                     type = int, default = 50,
58                     help = 'How many training epochs')
59
60 parser.add_argument('--batch_size',
61                     type = int, default = 100,
62                     help = 'Mini-batch size')
63
64 parser.add_argument('--log_file',
65                     type = str, default = 'cnn-svrt.log',
66                     help = 'Log file name')
67
68 parser.add_argument('--compress_vignettes',
69                     action='store_true', default = False,
70                     help = 'Should we use lossless compression of vignette to reduce the memory footprint')
71
72 args = parser.parse_args()
73
74 ######################################################################
75
76 log_file = open(args.log_file, 'w')
77
78 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
79
80 def log_string(s):
81     s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + s
82     log_file.write(s + '\n')
83     log_file.flush()
84     print(s)
85
86 ######################################################################
87
88 class VignetteSet:
89     def __init__(self, problem_number, nb_batches):
90         self.batch_size = args.batch_size
91         self.problem_number = problem_number
92         self.nb_batches = nb_batches
93         self.nb_samples = self.nb_batches * self.batch_size
94         self.targets = []
95         self.inputs = []
96
97         acc = 0.0
98         acc_sq = 0.0
99
100         for k in range(0, self.nb_batches):
101             target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
102             input = svrt.generate_vignettes(problem_number, target)
103             input = input.float().view(input.size(0), 1, input.size(1), input.size(2))
104             if torch.cuda.is_available():
105                 input = input.cuda()
106                 target = target.cuda()
107             acc += input.float().sum() / input.numel()
108             acc_sq += input.float().pow(2).sum() /  input.numel()
109             self.targets.append(target)
110             self.inputs.append(input)
111
112         mean = acc / self.nb_batches
113         std = math.sqrt(acc_sq / self.nb_batches - mean * mean)
114         for k in range(0, self.nb_batches):
115             self.inputs[k].sub_(mean).div_(std)
116
117     def get_batch(self, b):
118         return self.inputs[b], self.targets[b]
119
120 class CompressedVignetteSet:
121     def __init__(self, problem_number, nb_batches):
122         self.batch_size = args.batch_size
123         self.problem_number = problem_number
124         self.nb_batches = nb_batches
125         self.nb_samples = self.nb_batches * self.batch_size
126         self.targets = []
127         self.input_storages = []
128
129         acc = 0.0
130         acc_sq = 0.0
131         for k in range(0, self.nb_batches):
132             target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
133             input = svrt.generate_vignettes(problem_number, target)
134             acc += input.float().sum() / input.numel()
135             acc_sq += input.float().pow(2).sum() /  input.numel()
136             self.targets.append(target)
137             self.input_storages.append(svrt.compress(input.storage()))
138
139         self.mean = acc / self.nb_batches
140         self.std = math.sqrt(acc_sq / self.nb_batches - self.mean * self.mean)
141
142     def get_batch(self, b):
143         input = torch.ByteTensor(svrt.uncompress(self.input_storages[b])).float()
144         input = input.view(self.batch_size, 1, 128, 128).sub_(self.mean).div_(self.std)
145         target = self.targets[b]
146
147         if torch.cuda.is_available():
148             input = input.cuda()
149             target = target.cuda()
150
151         return input, target
152
153 ######################################################################
154
155 # Afroze's ShallowNet
156
157 #                       map size   nb. maps
158 #                     ----------------------
159 #    input                128x128    1
160 # -- conv(21x21 x 6)   -> 108x108    6
161 # -- max(2x2)          -> 54x54      6
162 # -- conv(19x19 x 16)  -> 36x36      16
163 # -- max(2x2)          -> 18x18      16
164 # -- conv(18x18 x 120) -> 1x1        120
165 # -- reshape           -> 120        1
166 # -- full(120x84)      -> 84         1
167 # -- full(84x2)        -> 2          1
168
169 class AfrozeShallowNet(nn.Module):
170     def __init__(self):
171         super(AfrozeShallowNet, self).__init__()
172         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
173         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
174         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
175         self.fc1 = nn.Linear(120, 84)
176         self.fc2 = nn.Linear(84, 2)
177
178     def forward(self, x):
179         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
180         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
181         x = fn.relu(self.conv3(x))
182         x = x.view(-1, 120)
183         x = fn.relu(self.fc1(x))
184         x = self.fc2(x)
185         return x
186
187 def train_model(model, train_set):
188     batch_size = args.batch_size
189     criterion = nn.CrossEntropyLoss()
190
191     if torch.cuda.is_available():
192         criterion.cuda()
193
194     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
195
196     for k in range(0, args.nb_epochs):
197         acc_loss = 0.0
198         for b in range(0, train_set.nb_batches):
199             input, target = train_set.get_batch(b)
200             output = model.forward(Variable(input))
201             loss = criterion(output, Variable(target))
202             acc_loss = acc_loss + loss.data[0]
203             model.zero_grad()
204             loss.backward()
205             optimizer.step()
206         log_string('train_loss {:d} {:f}'.format(k, acc_loss))
207
208     return model
209
210 ######################################################################
211
212 def nb_errors(model, data_set):
213     ne = 0
214     for b in range(0, data_set.nb_batches):
215         input, target = data_set.get_batch(b)
216         output = model.forward(Variable(input))
217         wta_prediction = output.data.max(1)[1].view(-1)
218
219         for i in range(0, data_set.batch_size):
220             if wta_prediction[i] != target[i]:
221                 ne = ne + 1
222
223     return ne
224
225 ######################################################################
226
227 for arg in vars(args):
228     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
229
230 for problem_number in range(1, 24):
231     if args.compress_vignettes:
232         train_set = CompressedVignetteSet(problem_number, args.nb_train_batches)
233         test_set = CompressedVignetteSet(problem_number, args.nb_test_batches)
234     else:
235         train_set = VignetteSet(problem_number, args.nb_train_batches)
236         test_set = VignetteSet(problem_number, args.nb_test_batches)
237
238     model = AfrozeShallowNet()
239
240     if torch.cuda.is_available():
241         model.cuda()
242
243     nb_parameters = 0
244     for p in model.parameters():
245         nb_parameters += p.numel()
246     log_string('nb_parameters {:d}'.format(nb_parameters))
247
248     model_filename = 'model_' + str(problem_number) + '.param'
249
250     try:
251         model.load_state_dict(torch.load(model_filename))
252         log_string('loaded_model ' + model_filename)
253     except:
254         log_string('training_model')
255         train_model(model, train_set)
256         torch.save(model.state_dict(), model_filename)
257         log_string('saved_model ' + model_filename)
258
259     nb_train_errors = nb_errors(model, train_set)
260
261     log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
262         problem_number,
263         100 * nb_train_errors / train_set.nb_samples,
264         nb_train_errors,
265         train_set.nb_samples)
266     )
267
268     nb_test_errors = nb_errors(model, test_set)
269
270     log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
271         problem_number,
272         100 * nb_test_errors / test_set.nb_samples,
273         nb_test_errors,
274         test_set.nb_samples)
275     )
276
277 ######################################################################