Make the name of the saved model more explicit.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with selector.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 import math
27
28 from colorama import Fore, Back, Style
29
30 import torch
31
32 from torch import optim
33 from torch import FloatTensor as Tensor
34 from torch.autograd import Variable
35 from torch import nn
36 from torch.nn import functional as fn
37 from torchvision import datasets, transforms, utils
38
39 from vignette_set import VignetteSet, CompressedVignetteSet
40
41 ######################################################################
42
43 parser = argparse.ArgumentParser(
44     description = 'Simple convnet test on the SVRT.',
45     formatter_class = argparse.ArgumentDefaultsHelpFormatter
46 )
47
48 parser.add_argument('--nb_train_batches',
49                     type = int, default = 1000,
50                     help = 'How many samples for train')
51
52 parser.add_argument('--nb_test_batches',
53                     type = int, default = 100,
54                     help = 'How many samples for test')
55
56 parser.add_argument('--nb_epochs',
57                     type = int, default = 50,
58                     help = 'How many training epochs')
59
60 parser.add_argument('--batch_size',
61                     type = int, default = 100,
62                     help = 'Mini-batch size')
63
64 parser.add_argument('--log_file',
65                     type = str, default = 'cnn-svrt.log',
66                     help = 'Log file name')
67
68 parser.add_argument('--compress_vignettes',
69                     action='store_true', default = False,
70                     help = 'Use lossless compression to reduce the memory footprint')
71
72 args = parser.parse_args()
73
74 ######################################################################
75
76 log_file = open(args.log_file, 'w')
77
78 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
79
80 def log_string(s):
81     s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + s
82     log_file.write(s + '\n')
83     log_file.flush()
84     print(s)
85
86 ######################################################################
87
88 # Afroze's ShallowNet
89
90 #                       map size   nb. maps
91 #                     ----------------------
92 #    input                128x128    1
93 # -- conv(21x21 x 6)   -> 108x108    6
94 # -- max(2x2)          -> 54x54      6
95 # -- conv(19x19 x 16)  -> 36x36      16
96 # -- max(2x2)          -> 18x18      16
97 # -- conv(18x18 x 120) -> 1x1        120
98 # -- reshape           -> 120        1
99 # -- full(120x84)      -> 84         1
100 # -- full(84x2)        -> 2          1
101
102 class AfrozeShallowNet(nn.Module):
103     def __init__(self):
104         super(AfrozeShallowNet, self).__init__()
105         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
106         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
107         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
108         self.fc1 = nn.Linear(120, 84)
109         self.fc2 = nn.Linear(84, 2)
110         self.name = 'shallownet'
111
112     def forward(self, x):
113         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
114         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
115         x = fn.relu(self.conv3(x))
116         x = x.view(-1, 120)
117         x = fn.relu(self.fc1(x))
118         x = self.fc2(x)
119         return x
120
121 ######################################################################
122
123 def train_model(model, train_set):
124     batch_size = args.batch_size
125     criterion = nn.CrossEntropyLoss()
126
127     if torch.cuda.is_available():
128         criterion.cuda()
129
130     optimizer = optim.SGD(model.parameters(), lr = 1e-2)
131
132     for e in range(0, args.nb_epochs):
133         acc_loss = 0.0
134         for b in range(0, train_set.nb_batches):
135             input, target = train_set.get_batch(b)
136             output = model.forward(Variable(input))
137             loss = criterion(output, Variable(target))
138             acc_loss = acc_loss + loss.data[0]
139             model.zero_grad()
140             loss.backward()
141             optimizer.step()
142         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss))
143
144     return model
145
146 ######################################################################
147
148 def nb_errors(model, data_set):
149     ne = 0
150     for b in range(0, data_set.nb_batches):
151         input, target = data_set.get_batch(b)
152         output = model.forward(Variable(input))
153         wta_prediction = output.data.max(1)[1].view(-1)
154
155         for i in range(0, data_set.batch_size):
156             if wta_prediction[i] != target[i]:
157                 ne = ne + 1
158
159     return ne
160
161 ######################################################################
162
163 for arg in vars(args):
164     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
165
166 for problem_number in range(1, 24):
167     if args.compress_vignettes:
168         train_set = CompressedVignetteSet(problem_number, args.nb_train_batches, args.batch_size)
169         test_set = CompressedVignetteSet(problem_number, args.nb_test_batches, args.batch_size)
170     else:
171         train_set = VignetteSet(problem_number, args.nb_train_batches, args.batch_size)
172         test_set = VignetteSet(problem_number, args.nb_test_batches, args.batch_size)
173
174     model = AfrozeShallowNet()
175
176     if torch.cuda.is_available():
177         model.cuda()
178
179     nb_parameters = 0
180     for p in model.parameters():
181         nb_parameters += p.numel()
182     log_string('nb_parameters {:d}'.format(nb_parameters))
183
184     model_filename = model.name + '_' + str(problem_number) + '_' + str(train_set.nb_batches) + '.param'
185
186     try:
187         model.load_state_dict(torch.load(model_filename))
188         log_string('loaded_model ' + model_filename)
189     except:
190         log_string('training_model')
191         train_model(model, train_set)
192         torch.save(model.state_dict(), model_filename)
193         log_string('saved_model ' + model_filename)
194
195     nb_train_errors = nb_errors(model, train_set)
196
197     log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
198         problem_number,
199         100 * nb_train_errors / train_set.nb_samples,
200         nb_train_errors,
201         train_set.nb_samples)
202     )
203
204     nb_test_errors = nb_errors(model, test_set)
205
206     log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
207         problem_number,
208         100 * nb_test_errors / test_set.nb_samples,
209         nb_test_errors,
210         test_set.nb_samples)
211     )
212
213 ######################################################################