Cosmetics.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with selector.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27
28 from colorama import Fore, Back, Style
29
30 import torch
31
32 from torch import optim
33 from torch import FloatTensor as Tensor
34 from torch.autograd import Variable
35 from torch import nn
36 from torch.nn import functional as fn
37 from torchvision import datasets, transforms, utils
38
39 import svrt
40
41 ######################################################################
42
43 parser = argparse.ArgumentParser(
44     description = 'Simple convnet test on the SVRT.',
45     formatter_class = argparse.ArgumentDefaultsHelpFormatter
46 )
47
48 parser.add_argument('--nb_train_batches',
49                     type = int, default = 1000,
50                     help = 'How many samples for train')
51
52 parser.add_argument('--nb_test_batches',
53                     type = int, default = 100,
54                     help = 'How many samples for test')
55
56 parser.add_argument('--nb_epochs',
57                     type = int, default = 50,
58                     help = 'How many training epochs')
59
60 parser.add_argument('--batch_size',
61                     type = int, default = 100,
62                     help = 'Mini-batch size')
63
64 parser.add_argument('--log_file',
65                     type = str, default = 'cnn-svrt.log',
66                     help = 'Log file name')
67
68 parser.add_argument('--compress_vignettes',
69                     action='store_true', default = False,
70                     help = 'Use lossless compression to reduce the memory footprint')
71
72 args = parser.parse_args()
73
74 ######################################################################
75
76 log_file = open(args.log_file, 'w')
77
78 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
79
80 def log_string(s):
81     s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + s
82     log_file.write(s + '\n')
83     log_file.flush()
84     print(s)
85
86 ######################################################################
87
88 class VignetteSet:
89     def __init__(self, problem_number, nb_batches):
90         self.batch_size = args.batch_size
91         self.problem_number = problem_number
92         self.nb_batches = nb_batches
93         self.nb_samples = self.nb_batches * self.batch_size
94         self.targets = []
95         self.inputs = []
96
97         acc = 0.0
98         acc_sq = 0.0
99
100         for b in range(0, self.nb_batches):
101             target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
102             input = svrt.generate_vignettes(problem_number, target)
103             input = input.float().view(input.size(0), 1, input.size(1), input.size(2))
104             if torch.cuda.is_available():
105                 input = input.cuda()
106                 target = target.cuda()
107             acc += input.sum() / input.numel()
108             acc_sq += input.pow(2).sum() /  input.numel()
109             self.targets.append(target)
110             self.inputs.append(input)
111
112         mean = acc / self.nb_batches
113         std = math.sqrt(acc_sq / self.nb_batches - mean * mean)
114         for b in range(0, self.nb_batches):
115             self.inputs[b].sub_(mean).div_(std)
116
117     def get_batch(self, b):
118         return self.inputs[b], self.targets[b]
119
120 ######################################################################
121
122 class CompressedVignetteSet:
123     def __init__(self, problem_number, nb_batches):
124         self.batch_size = args.batch_size
125         self.problem_number = problem_number
126         self.nb_batches = nb_batches
127         self.nb_samples = self.nb_batches * self.batch_size
128         self.targets = []
129         self.input_storages = []
130
131         acc = 0.0
132         acc_sq = 0.0
133         for b in range(0, self.nb_batches):
134             target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
135             input = svrt.generate_vignettes(problem_number, target)
136             acc += input.float().sum() / input.numel()
137             acc_sq += input.float().pow(2).sum() /  input.numel()
138             self.targets.append(target)
139             self.input_storages.append(svrt.compress(input.storage()))
140
141         self.mean = acc / self.nb_batches
142         self.std = math.sqrt(acc_sq / self.nb_batches - self.mean * self.mean)
143
144     def get_batch(self, b):
145         input = torch.ByteTensor(svrt.uncompress(self.input_storages[b])).float()
146         input = input.view(self.batch_size, 1, 128, 128).sub_(self.mean).div_(self.std)
147         target = self.targets[b]
148
149         if torch.cuda.is_available():
150             input = input.cuda()
151             target = target.cuda()
152
153         return input, target
154
155 ######################################################################
156
157 # Afroze's ShallowNet
158
159 #                       map size   nb. maps
160 #                     ----------------------
161 #    input                128x128    1
162 # -- conv(21x21 x 6)   -> 108x108    6
163 # -- max(2x2)          -> 54x54      6
164 # -- conv(19x19 x 16)  -> 36x36      16
165 # -- max(2x2)          -> 18x18      16
166 # -- conv(18x18 x 120) -> 1x1        120
167 # -- reshape           -> 120        1
168 # -- full(120x84)      -> 84         1
169 # -- full(84x2)        -> 2          1
170
171 class AfrozeShallowNet(nn.Module):
172     def __init__(self):
173         super(AfrozeShallowNet, self).__init__()
174         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
175         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
176         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
177         self.fc1 = nn.Linear(120, 84)
178         self.fc2 = nn.Linear(84, 2)
179
180     def forward(self, x):
181         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
182         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
183         x = fn.relu(self.conv3(x))
184         x = x.view(-1, 120)
185         x = fn.relu(self.fc1(x))
186         x = self.fc2(x)
187         return x
188
189 def train_model(model, train_set):
190     batch_size = args.batch_size
191     criterion = nn.CrossEntropyLoss()
192
193     if torch.cuda.is_available():
194         criterion.cuda()
195
196     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
197
198     for e in range(0, args.nb_epochs):
199         acc_loss = 0.0
200         for b in range(0, train_set.nb_batches):
201             input, target = train_set.get_batch(b)
202             output = model.forward(Variable(input))
203             loss = criterion(output, Variable(target))
204             acc_loss = acc_loss + loss.data[0]
205             model.zero_grad()
206             loss.backward()
207             optimizer.step()
208         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss))
209
210     return model
211
212 ######################################################################
213
214 def nb_errors(model, data_set):
215     ne = 0
216     for b in range(0, data_set.nb_batches):
217         input, target = data_set.get_batch(b)
218         output = model.forward(Variable(input))
219         wta_prediction = output.data.max(1)[1].view(-1)
220
221         for i in range(0, data_set.batch_size):
222             if wta_prediction[i] != target[i]:
223                 ne = ne + 1
224
225     return ne
226
227 ######################################################################
228
229 for arg in vars(args):
230     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
231
232 for problem_number in range(1, 24):
233     if args.compress_vignettes:
234         train_set = CompressedVignetteSet(problem_number, args.nb_train_batches)
235         test_set = CompressedVignetteSet(problem_number, args.nb_test_batches)
236     else:
237         train_set = VignetteSet(problem_number, args.nb_train_batches)
238         test_set = VignetteSet(problem_number, args.nb_test_batches)
239
240     model = AfrozeShallowNet()
241
242     if torch.cuda.is_available():
243         model.cuda()
244
245     nb_parameters = 0
246     for p in model.parameters():
247         nb_parameters += p.numel()
248     log_string('nb_parameters {:d}'.format(nb_parameters))
249
250     model_filename = 'model_' + str(problem_number) + '.param'
251
252     try:
253         model.load_state_dict(torch.load(model_filename))
254         log_string('loaded_model ' + model_filename)
255     except:
256         log_string('training_model')
257         train_model(model, train_set)
258         torch.save(model.state_dict(), model_filename)
259         log_string('saved_model ' + model_filename)
260
261     nb_train_errors = nb_errors(model, train_set)
262
263     log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
264         problem_number,
265         100 * nb_train_errors / train_set.nb_samples,
266         nb_train_errors,
267         train_set.nb_samples)
268     )
269
270     nb_test_errors = nb_errors(model, test_set)
271
272     log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
273         problem_number,
274         100 * nb_test_errors / test_set.nb_samples,
275         nb_test_errors,
276         test_set.nb_samples)
277     )
278
279 ######################################################################