Added persistence of the model parameters.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with selector.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26
27 from colorama import Fore, Back, Style
28
29 import torch
30
31 from torch import optim
32 from torch import FloatTensor as Tensor
33 from torch.autograd import Variable
34 from torch import nn
35 from torch.nn import functional as fn
36 from torchvision import datasets, transforms, utils
37
38 import svrt
39
40 ######################################################################
41
42 parser = argparse.ArgumentParser(
43     description = 'Simple convnet test on the SVRT.',
44     formatter_class = argparse.ArgumentDefaultsHelpFormatter
45 )
46
47 parser.add_argument('--nb_train_samples',
48                     type = int, default = 100000,
49                     help = 'How many samples for train')
50
51 parser.add_argument('--nb_test_samples',
52                     type = int, default = 10000,
53                     help = 'How many samples for test')
54
55 parser.add_argument('--nb_epochs',
56                     type = int, default = 50,
57                     help = 'How many training epochs')
58
59 parser.add_argument('--log_file',
60                     type = str, default = 'cnn-svrt.log',
61                     help = 'Log file name')
62
63 args = parser.parse_args()
64
65 ######################################################################
66
67 log_file = open(args.log_file, 'w')
68
69 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
70
71 def log_string(s):
72     s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + s
73     log_file.write(s + '\n')
74     log_file.flush()
75     print(s)
76
77 ######################################################################
78
79 def generate_set(p, n):
80     target = torch.LongTensor(n).bernoulli_(0.5)
81     t = time.time()
82     input = svrt.generate_vignettes(p, target)
83     t = time.time() - t
84     log_string('data_set_generation {:.02f} sample/s'.format(n / t))
85     input = input.view(input.size(0), 1, input.size(1), input.size(2)).float()
86     return Variable(input), Variable(target)
87
88 ######################################################################
89
90 # Afroze's ShallowNet
91
92 #                       map size   nb. maps
93 #                     ----------------------
94 #    input                128x128    1
95 # -- conv(21x21 x 6)   -> 108x108    6
96 # -- max(2x2)          -> 54x54      6
97 # -- conv(19x19 x 16)  -> 36x36      16
98 # -- max(2x2)          -> 18x18      16
99 # -- conv(18x18 x 120) -> 1x1        120
100 # -- reshape           -> 120        1
101 # -- full(120x84)      -> 84         1
102 # -- full(84x2)        -> 2          1
103
104 class AfrozeShallowNet(nn.Module):
105     def __init__(self):
106         super(AfrozeShallowNet, self).__init__()
107         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
108         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
109         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
110         self.fc1 = nn.Linear(120, 84)
111         self.fc2 = nn.Linear(84, 2)
112
113     def forward(self, x):
114         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
115         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
116         x = fn.relu(self.conv3(x))
117         x = x.view(-1, 120)
118         x = fn.relu(self.fc1(x))
119         x = self.fc2(x)
120         return x
121
122 def train_model(model, train_input, train_target):
123     criterion = nn.CrossEntropyLoss()
124
125     if torch.cuda.is_available():
126         criterion.cuda()
127
128     optimizer, bs = optim.SGD(model.parameters(), lr = 1e-2), 100
129
130     for k in range(0, args.nb_epochs):
131         acc_loss = 0.0
132         for b in range(0, train_input.size(0), bs):
133             output = model.forward(train_input.narrow(0, b, bs))
134             loss = criterion(output, train_target.narrow(0, b, bs))
135             acc_loss = acc_loss + loss.data[0]
136             model.zero_grad()
137             loss.backward()
138             optimizer.step()
139         log_string('train_loss {:d} {:f}'.format(k, acc_loss))
140
141     return model
142
143 ######################################################################
144
145 def nb_errors(model, data_input, data_target, bs = 100):
146     ne = 0
147
148     for b in range(0, data_input.size(0), bs):
149         output = model.forward(data_input.narrow(0, b, bs))
150         wta_prediction = output.data.max(1)[1].view(-1)
151
152         for i in range(0, bs):
153             if wta_prediction[i] != data_target.narrow(0, b, bs).data[i]:
154                 ne = ne + 1
155
156     return ne
157
158 ######################################################################
159
160 for arg in vars(args):
161     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
162
163 for problem_number in range(1, 24):
164     train_input, train_target = generate_set(problem_number, args.nb_train_samples)
165     test_input, test_target = generate_set(problem_number, args.nb_test_samples)
166     model = AfrozeShallowNet()
167
168     if torch.cuda.is_available():
169         train_input, train_target = train_input.cuda(), train_target.cuda()
170         test_input, test_target = test_input.cuda(), test_target.cuda()
171         model.cuda()
172
173     mu, std = train_input.data.mean(), train_input.data.std()
174     train_input.data.sub_(mu).div_(std)
175     test_input.data.sub_(mu).div_(std)
176
177     nb_parameters = 0
178     for p in model.parameters():
179         nb_parameters += p.numel()
180     log_string('nb_parameters {:d}'.format(nb_parameters))
181
182     model_filename = 'model_' + str(problem_number) + '.param'
183
184     try:
185         model.load_state_dict(torch.load(model_filename))
186         log_string('loaded_model ' + model_filename)
187     except:
188         log_string('training_model')
189         train_model(model, train_input, train_target)
190         torch.save(model.state_dict(), model_filename)
191         log_string('saved_model ' + model_filename)
192
193     nb_train_errors = nb_errors(model, train_input, train_target)
194
195     log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
196         problem_number,
197         100 * nb_train_errors / train_input.size(0),
198         nb_train_errors,
199         train_input.size(0))
200     )
201
202     nb_test_errors = nb_errors(model, test_input, test_target)
203
204     log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
205         problem_number,
206         100 * nb_test_errors / test_input.size(0),
207         nb_test_errors,
208         test_input.size(0))
209     )
210
211 ######################################################################