8b8ec124e0545feb873d059491032c4277159299
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with selector.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27
28 from colorama import Fore, Back, Style
29
30 # Pytorch
31
32 import torch
33
34 from torch import optim
35 from torch import FloatTensor as Tensor
36 from torch.autograd import Variable
37 from torch import nn
38 from torch.nn import functional as fn
39 from torchvision import datasets, transforms, utils
40
41 # SVRT
42
43 from vignette_set import VignetteSet, CompressedVignetteSet
44
45 ######################################################################
46
47 parser = argparse.ArgumentParser(
48     description = 'Simple convnet test on the SVRT.',
49     formatter_class = argparse.ArgumentDefaultsHelpFormatter
50 )
51
52 parser.add_argument('--nb_train_batches',
53                     type = int, default = 1000,
54                     help = 'How many samples for train')
55
56 parser.add_argument('--nb_test_batches',
57                     type = int, default = 100,
58                     help = 'How many samples for test')
59
60 parser.add_argument('--nb_epochs',
61                     type = int, default = 50,
62                     help = 'How many training epochs')
63
64 parser.add_argument('--batch_size',
65                     type = int, default = 100,
66                     help = 'Mini-batch size')
67
68 parser.add_argument('--log_file',
69                     type = str, default = 'cnn-svrt.log',
70                     help = 'Log file name')
71
72 parser.add_argument('--compress_vignettes',
73                     action='store_true', default = False,
74                     help = 'Use lossless compression to reduce the memory footprint')
75
76 args = parser.parse_args()
77
78 ######################################################################
79
80 log_file = open(args.log_file, 'w')
81
82 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
83
84 def log_string(s):
85     s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + s
86     log_file.write(s + '\n')
87     log_file.flush()
88     print(s)
89
90 ######################################################################
91
92 # Afroze's ShallowNet
93
94 #                       map size   nb. maps
95 #                     ----------------------
96 #    input                128x128    1
97 # -- conv(21x21 x 6)   -> 108x108    6
98 # -- max(2x2)          -> 54x54      6
99 # -- conv(19x19 x 16)  -> 36x36      16
100 # -- max(2x2)          -> 18x18      16
101 # -- conv(18x18 x 120) -> 1x1        120
102 # -- reshape           -> 120        1
103 # -- full(120x84)      -> 84         1
104 # -- full(84x2)        -> 2          1
105
106 class AfrozeShallowNet(nn.Module):
107     def __init__(self):
108         super(AfrozeShallowNet, self).__init__()
109         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
110         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
111         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
112         self.fc1 = nn.Linear(120, 84)
113         self.fc2 = nn.Linear(84, 2)
114         self.name = 'shallownet'
115
116     def forward(self, x):
117         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
118         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
119         x = fn.relu(self.conv3(x))
120         x = x.view(-1, 120)
121         x = fn.relu(self.fc1(x))
122         x = self.fc2(x)
123         return x
124
125 ######################################################################
126
127 def train_model(model, train_set):
128     batch_size = args.batch_size
129     criterion = nn.CrossEntropyLoss()
130
131     if torch.cuda.is_available():
132         criterion.cuda()
133
134     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
135
136     for e in range(0, args.nb_epochs):
137         acc_loss = 0.0
138         for b in range(0, train_set.nb_batches):
139             input, target = train_set.get_batch(b)
140             output = model.forward(Variable(input))
141             loss = criterion(output, Variable(target))
142             acc_loss = acc_loss + loss.data[0]
143             model.zero_grad()
144             loss.backward()
145             optimizer.step()
146         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss))
147
148     return model
149
150 ######################################################################
151
152 def nb_errors(model, data_set):
153     ne = 0
154     for b in range(0, data_set.nb_batches):
155         input, target = data_set.get_batch(b)
156         output = model.forward(Variable(input))
157         wta_prediction = output.data.max(1)[1].view(-1)
158
159         for i in range(0, data_set.batch_size):
160             if wta_prediction[i] != target[i]:
161                 ne = ne + 1
162
163     return ne
164
165 ######################################################################
166
167 for arg in vars(args):
168     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
169
170 for problem_number in range(1, 24):
171     if args.compress_vignettes:
172         train_set = CompressedVignetteSet(problem_number, args.nb_train_batches, args.batch_size,
173                                           cuda=torch.cuda.is_available())
174         test_set = CompressedVignetteSet(problem_number, args.nb_test_batches, args.batch_size,
175                                          cuda=torch.cuda.is_available())
176     else:
177         train_set = VignetteSet(problem_number, args.nb_train_batches, args.batch_size,
178                                           cuda=torch.cuda.is_available())
179         test_set = VignetteSet(problem_number, args.nb_test_batches, args.batch_size,
180                                           cuda=torch.cuda.is_available())
181
182     model = AfrozeShallowNet()
183
184     if torch.cuda.is_available():
185         model.cuda()
186
187     nb_parameters = 0
188     for p in model.parameters():
189         nb_parameters += p.numel()
190     log_string('nb_parameters {:d}'.format(nb_parameters))
191
192     model_filename = model.name + '_' + str(problem_number) + '_' + str(train_set.nb_batches) + '.param'
193
194     try:
195         model.load_state_dict(torch.load(model_filename))
196         log_string('loaded_model ' + model_filename)
197     except:
198         log_string('training_model')
199         train_model(model, train_set)
200         torch.save(model.state_dict(), model_filename)
201         log_string('saved_model ' + model_filename)
202
203     nb_train_errors = nb_errors(model, train_set)
204
205     log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
206         problem_number,
207         100 * nb_train_errors / train_set.nb_samples,
208         nb_train_errors,
209         train_set.nb_samples)
210     )
211
212     nb_test_errors = nb_errors(model, test_set)
213
214     log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
215         problem_number,
216         100 * nb_test_errors / test_set.nb_samples,
217         nb_test_errors,
218         test_set.nb_samples)
219     )
220
221 ######################################################################