Replace the numbers of samples by numbers of batches of samples.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with selector.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26
27 from colorama import Fore, Back, Style
28
29 import torch
30
31 from torch import optim
32 from torch import FloatTensor as Tensor
33 from torch.autograd import Variable
34 from torch import nn
35 from torch.nn import functional as fn
36 from torchvision import datasets, transforms, utils
37
38 import svrt
39
40 ######################################################################
41
42 parser = argparse.ArgumentParser(
43     description = 'Simple convnet test on the SVRT.',
44     formatter_class = argparse.ArgumentDefaultsHelpFormatter
45 )
46
47 parser.add_argument('--nb_train_batches',
48                     type = int, default = 1000,
49                     help = 'How many samples for train')
50
51 parser.add_argument('--nb_test_batches',
52                     type = int, default = 100,
53                     help = 'How many samples for test')
54
55 parser.add_argument('--nb_epochs',
56                     type = int, default = 50,
57                     help = 'How many training epochs')
58
59 parser.add_argument('--batch_size',
60                     type = int, default = 100,
61                     help = 'Mini-batch size')
62
63 parser.add_argument('--log_file',
64                     type = str, default = 'cnn-svrt.log',
65                     help = 'Log file name')
66
67 args = parser.parse_args()
68
69 ######################################################################
70
71 log_file = open(args.log_file, 'w')
72
73 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
74
75 def log_string(s):
76     s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + s
77     log_file.write(s + '\n')
78     log_file.flush()
79     print(s)
80
81 ######################################################################
82
83 def generate_set(p, n):
84     target = torch.LongTensor(n).bernoulli_(0.5)
85     t = time.time()
86     input = svrt.generate_vignettes(p, target)
87     t = time.time() - t
88     log_string('data_set_generation {:.02f} sample/s'.format(n / t))
89     input = input.view(input.size(0), 1, input.size(1), input.size(2)).float()
90     return Variable(input), Variable(target)
91
92 ######################################################################
93
94 # Afroze's ShallowNet
95
96 #                       map size   nb. maps
97 #                     ----------------------
98 #    input                128x128    1
99 # -- conv(21x21 x 6)   -> 108x108    6
100 # -- max(2x2)          -> 54x54      6
101 # -- conv(19x19 x 16)  -> 36x36      16
102 # -- max(2x2)          -> 18x18      16
103 # -- conv(18x18 x 120) -> 1x1        120
104 # -- reshape           -> 120        1
105 # -- full(120x84)      -> 84         1
106 # -- full(84x2)        -> 2          1
107
108 class AfrozeShallowNet(nn.Module):
109     def __init__(self):
110         super(AfrozeShallowNet, self).__init__()
111         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
112         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
113         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
114         self.fc1 = nn.Linear(120, 84)
115         self.fc2 = nn.Linear(84, 2)
116
117     def forward(self, x):
118         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
119         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
120         x = fn.relu(self.conv3(x))
121         x = x.view(-1, 120)
122         x = fn.relu(self.fc1(x))
123         x = self.fc2(x)
124         return x
125
126 def train_model(model, train_input, train_target):
127     bs = args.batch_size
128     criterion = nn.CrossEntropyLoss()
129
130     if torch.cuda.is_available():
131         criterion.cuda()
132
133     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
134
135     for k in range(0, args.nb_epochs):
136         acc_loss = 0.0
137         for b in range(0, train_input.size(0), bs):
138             output = model.forward(train_input.narrow(0, b, bs))
139             loss = criterion(output, train_target.narrow(0, b, bs))
140             acc_loss = acc_loss + loss.data[0]
141             model.zero_grad()
142             loss.backward()
143             optimizer.step()
144         log_string('train_loss {:d} {:f}'.format(k, acc_loss))
145
146     return model
147
148 ######################################################################
149
150 def nb_errors(model, data_input, data_target):
151     bs = args.batch_size
152
153     ne = 0
154     for b in range(0, data_input.size(0), bs):
155         output = model.forward(data_input.narrow(0, b, bs))
156         wta_prediction = output.data.max(1)[1].view(-1)
157
158         for i in range(0, bs):
159             if wta_prediction[i] != data_target.narrow(0, b, bs).data[i]:
160                 ne = ne + 1
161
162     return ne
163
164 ######################################################################
165
166 for arg in vars(args):
167     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
168
169 for problem_number in range(1, 24):
170     train_input, train_target = generate_set(problem_number,
171                                              args.nb_train_batches * args.batch_size)
172     test_input, test_target = generate_set(problem_number,
173                                            args.nb_test_batches * args.batch_size)
174     model = AfrozeShallowNet()
175
176     if torch.cuda.is_available():
177         train_input, train_target = train_input.cuda(), train_target.cuda()
178         test_input, test_target = test_input.cuda(), test_target.cuda()
179         model.cuda()
180
181     mu, std = train_input.data.mean(), train_input.data.std()
182     train_input.data.sub_(mu).div_(std)
183     test_input.data.sub_(mu).div_(std)
184
185     nb_parameters = 0
186     for p in model.parameters():
187         nb_parameters += p.numel()
188     log_string('nb_parameters {:d}'.format(nb_parameters))
189
190     model_filename = 'model_' + str(problem_number) + '.param'
191
192     try:
193         model.load_state_dict(torch.load(model_filename))
194         log_string('loaded_model ' + model_filename)
195     except:
196         log_string('training_model')
197         train_model(model, train_input, train_target)
198         torch.save(model.state_dict(), model_filename)
199         log_string('saved_model ' + model_filename)
200
201     nb_train_errors = nb_errors(model, train_input, train_target)
202
203     log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
204         problem_number,
205         100 * nb_train_errors / train_input.size(0),
206         nb_train_errors,
207         train_input.size(0))
208     )
209
210     nb_test_errors = nb_errors(model, test_input, test_target)
211
212     log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
213         problem_number,
214         100 * nb_test_errors / test_input.size(0),
215         nb_test_errors,
216         test_input.size(0))
217     )
218
219 ######################################################################