7bef242de186a1bbcbad4a84444c5c3311a9445e
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with selector.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27
28 from colorama import Fore, Back, Style
29
30 # Pytorch
31
32 import torch
33
34 from torch import optim
35 from torch import FloatTensor as Tensor
36 from torch.autograd import Variable
37 from torch import nn
38 from torch.nn import functional as fn
39 from torchvision import datasets, transforms, utils
40
41 # SVRT
42
43 from vignette_set import VignetteSet, CompressedVignetteSet
44
45 ######################################################################
46
47 parser = argparse.ArgumentParser(
48     description = 'Simple convnet test on the SVRT.',
49     formatter_class = argparse.ArgumentDefaultsHelpFormatter
50 )
51
52 parser.add_argument('--nb_train_batches',
53                     type = int, default = 1000,
54                     help = 'How many samples for train')
55
56 parser.add_argument('--nb_test_batches',
57                     type = int, default = 100,
58                     help = 'How many samples for test')
59
60 parser.add_argument('--nb_epochs',
61                     type = int, default = 50,
62                     help = 'How many training epochs')
63
64 parser.add_argument('--batch_size',
65                     type = int, default = 100,
66                     help = 'Mini-batch size')
67
68 parser.add_argument('--log_file',
69                     type = str, default = 'cnn-svrt.log',
70                     help = 'Log file name')
71
72 parser.add_argument('--compress_vignettes',
73                     action='store_true', default = False,
74                     help = 'Use lossless compression to reduce the memory footprint')
75
76 args = parser.parse_args()
77
78 ######################################################################
79
80 log_file = open(args.log_file, 'w')
81
82 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
83
84 def log_string(s):
85     s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + s
86     log_file.write(s + '\n')
87     log_file.flush()
88     print(s)
89
90 ######################################################################
91
92 # Afroze's ShallowNet
93
94 #                       map size   nb. maps
95 #                     ----------------------
96 #    input                128x128    1
97 # -- conv(21x21 x 6)   -> 108x108    6
98 # -- max(2x2)          -> 54x54      6
99 # -- conv(19x19 x 16)  -> 36x36      16
100 # -- max(2x2)          -> 18x18      16
101 # -- conv(18x18 x 120) -> 1x1        120
102 # -- reshape           -> 120        1
103 # -- full(120x84)      -> 84         1
104 # -- full(84x2)        -> 2          1
105
106 class AfrozeShallowNet(nn.Module):
107     def __init__(self):
108         super(AfrozeShallowNet, self).__init__()
109         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
110         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
111         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
112         self.fc1 = nn.Linear(120, 84)
113         self.fc2 = nn.Linear(84, 2)
114         self.name = 'shallownet'
115
116     def forward(self, x):
117         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
118         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
119         x = fn.relu(self.conv3(x))
120         x = x.view(-1, 120)
121         x = fn.relu(self.fc1(x))
122         x = self.fc2(x)
123         return x
124
125 ######################################################################
126
127 def train_model(model, train_set):
128     batch_size = args.batch_size
129     criterion = nn.CrossEntropyLoss()
130
131     if torch.cuda.is_available():
132         criterion.cuda()
133
134     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
135
136     for e in range(0, args.nb_epochs):
137         acc_loss = 0.0
138         for b in range(0, train_set.nb_batches):
139             input, target = train_set.get_batch(b)
140             output = model.forward(Variable(input))
141             loss = criterion(output, Variable(target))
142             acc_loss = acc_loss + loss.data[0]
143             model.zero_grad()
144             loss.backward()
145             optimizer.step()
146         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss))
147
148     return model
149
150 ######################################################################
151
152 def nb_errors(model, data_set):
153     ne = 0
154     for b in range(0, data_set.nb_batches):
155         input, target = data_set.get_batch(b)
156         output = model.forward(Variable(input))
157         wta_prediction = output.data.max(1)[1].view(-1)
158
159         for i in range(0, data_set.batch_size):
160             if wta_prediction[i] != target[i]:
161                 ne = ne + 1
162
163     return ne
164
165 ######################################################################
166
167 for arg in vars(args):
168     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
169
170 ######################################################################
171
172 for problem_number in range(1, 24):
173
174     model = AfrozeShallowNet()
175
176     if torch.cuda.is_available():
177         model.cuda()
178
179     model_filename = model.name + '_' + \
180                      str(problem_number) + '_' + \
181                      str(args.nb_train_batches) + '.param'
182
183     nb_parameters = 0
184     for p in model.parameters(): nb_parameters += p.numel()
185     log_string('nb_parameters {:d}'.format(nb_parameters))
186
187     try:
188
189         model.load_state_dict(torch.load(model_filename))
190         log_string('loaded_model ' + model_filename)
191
192     except:
193
194         log_string('training_model ' + model_filename)
195
196         if args.compress_vignettes:
197             train_set = CompressedVignetteSet(problem_number,
198                                               args.nb_train_batches, args.batch_size,
199                                               cuda=torch.cuda.is_available())
200             test_set = CompressedVignetteSet(problem_number,
201                                              args.nb_test_batches, args.batch_size,
202                                              cuda=torch.cuda.is_available())
203         else:
204             train_set = VignetteSet(problem_number,
205                                     args.nb_train_batches, args.batch_size,
206                                     cuda=torch.cuda.is_available())
207             test_set = VignetteSet(problem_number,
208                                    args.nb_test_batches, args.batch_size,
209                                    cuda=torch.cuda.is_available())
210
211         train_model(model, train_set)
212         torch.save(model.state_dict(), model_filename)
213         log_string('saved_model ' + model_filename)
214
215         nb_train_errors = nb_errors(model, train_set)
216
217         log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
218             problem_number,
219             100 * nb_train_errors / train_set.nb_samples,
220             nb_train_errors,
221             train_set.nb_samples)
222         )
223
224         nb_test_errors = nb_errors(model, test_set)
225
226         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
227             problem_number,
228             100 * nb_test_errors / test_set.nb_samples,
229             nb_test_errors,
230             test_set.nb_samples)
231         )
232
233 ######################################################################