Update.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with selector.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27
28 from colorama import Fore, Back, Style
29
30 import torch
31
32 from torch import optim
33 from torch import FloatTensor as Tensor
34 from torch.autograd import Variable
35 from torch import nn
36 from torch.nn import functional as fn
37 from torchvision import datasets, transforms, utils
38
39 from vignette_set import VignetteSet, CompressedVignetteSet
40
41 ######################################################################
42
43 parser = argparse.ArgumentParser(
44     description = 'Simple convnet test on the SVRT.',
45     formatter_class = argparse.ArgumentDefaultsHelpFormatter
46 )
47
48 parser.add_argument('--nb_train_batches',
49                     type = int, default = 1000,
50                     help = 'How many samples for train')
51
52 parser.add_argument('--nb_test_batches',
53                     type = int, default = 100,
54                     help = 'How many samples for test')
55
56 parser.add_argument('--nb_epochs',
57                     type = int, default = 50,
58                     help = 'How many training epochs')
59
60 parser.add_argument('--batch_size',
61                     type = int, default = 100,
62                     help = 'Mini-batch size')
63
64 parser.add_argument('--log_file',
65                     type = str, default = 'cnn-svrt.log',
66                     help = 'Log file name')
67
68 parser.add_argument('--compress_vignettes',
69                     action='store_true', default = False,
70                     help = 'Use lossless compression to reduce the memory footprint')
71
72 args = parser.parse_args()
73
74 ######################################################################
75
76 log_file = open(args.log_file, 'w')
77
78 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
79
80 def log_string(s):
81     s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + s
82     log_file.write(s + '\n')
83     log_file.flush()
84     print(s)
85
86 ######################################################################
87
88 # Afroze's ShallowNet
89
90 #                       map size   nb. maps
91 #                     ----------------------
92 #    input                128x128    1
93 # -- conv(21x21 x 6)   -> 108x108    6
94 # -- max(2x2)          -> 54x54      6
95 # -- conv(19x19 x 16)  -> 36x36      16
96 # -- max(2x2)          -> 18x18      16
97 # -- conv(18x18 x 120) -> 1x1        120
98 # -- reshape           -> 120        1
99 # -- full(120x84)      -> 84         1
100 # -- full(84x2)        -> 2          1
101
102 class AfrozeShallowNet(nn.Module):
103     def __init__(self):
104         super(AfrozeShallowNet, self).__init__()
105         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
106         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
107         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
108         self.fc1 = nn.Linear(120, 84)
109         self.fc2 = nn.Linear(84, 2)
110
111     def forward(self, x):
112         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
113         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
114         x = fn.relu(self.conv3(x))
115         x = x.view(-1, 120)
116         x = fn.relu(self.fc1(x))
117         x = self.fc2(x)
118         return x
119
120 def train_model(model, train_set):
121     batch_size = args.batch_size
122     criterion = nn.CrossEntropyLoss()
123
124     if torch.cuda.is_available():
125         criterion.cuda()
126
127     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
128
129     for e in range(0, args.nb_epochs):
130         acc_loss = 0.0
131         for b in range(0, train_set.nb_batches):
132             input, target = train_set.get_batch(b)
133             output = model.forward(Variable(input))
134             loss = criterion(output, Variable(target))
135             acc_loss = acc_loss + loss.data[0]
136             model.zero_grad()
137             loss.backward()
138             optimizer.step()
139         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss))
140
141     return model
142
143 ######################################################################
144
145 def nb_errors(model, data_set):
146     ne = 0
147     for b in range(0, data_set.nb_batches):
148         input, target = data_set.get_batch(b)
149         output = model.forward(Variable(input))
150         wta_prediction = output.data.max(1)[1].view(-1)
151
152         for i in range(0, data_set.batch_size):
153             if wta_prediction[i] != target[i]:
154                 ne = ne + 1
155
156     return ne
157
158 ######################################################################
159
160 for arg in vars(args):
161     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
162
163 for problem_number in range(1, 24):
164     if args.compress_vignettes:
165         train_set = CompressedVignetteSet(problem_number, args.nb_train_batches, args.batch_size)
166         test_set = CompressedVignetteSet(problem_number, args.nb_test_batches, args.batch_size)
167     else:
168         train_set = VignetteSet(problem_number, args.nb_train_batches, args.batch_size)
169         test_set = VignetteSet(problem_number, args.nb_test_batches, args.batch_size)
170
171     model = AfrozeShallowNet()
172
173     if torch.cuda.is_available():
174         model.cuda()
175
176     nb_parameters = 0
177     for p in model.parameters():
178         nb_parameters += p.numel()
179     log_string('nb_parameters {:d}'.format(nb_parameters))
180
181     model_filename = 'model_' + str(problem_number) + '.param'
182
183     try:
184         model.load_state_dict(torch.load(model_filename))
185         log_string('loaded_model ' + model_filename)
186     except:
187         log_string('training_model')
188         train_model(model, train_set)
189         torch.save(model.state_dict(), model_filename)
190         log_string('saved_model ' + model_filename)
191
192     nb_train_errors = nb_errors(model, train_set)
193
194     log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
195         problem_number,
196         100 * nb_train_errors / train_set.nb_samples,
197         nb_train_errors,
198         train_set.nb_samples)
199     )
200
201     nb_test_errors = nb_errors(model, test_set)
202
203     log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
204         problem_number,
205         100 * nb_test_errors / test_set.nb_samples,
206         nb_test_errors,
207         test_set.nb_samples)
208     )
209
210 ######################################################################