Changed the network structure to Afroze's ShallowNet.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with selector.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26 from colorama import Fore, Back, Style
27
28 import torch
29
30 from torch import optim
31 from torch import FloatTensor as Tensor
32 from torch.autograd import Variable
33 from torch import nn
34 from torch.nn import functional as fn
35 from torchvision import datasets, transforms, utils
36
37 import svrt
38
39 ######################################################################
40
41 parser = argparse.ArgumentParser(
42     description = 'Simple convnet test on the SVRT.',
43     formatter_class = argparse.ArgumentDefaultsHelpFormatter
44 )
45
46 parser.add_argument('--nb_train_samples',
47                     type = int, default = 100000,
48                     help = 'How many samples for train')
49
50 parser.add_argument('--nb_test_samples',
51                     type = int, default = 10000,
52                     help = 'How many samples for test')
53
54 parser.add_argument('--nb_epochs',
55                     type = int, default = 25,
56                     help = 'How many training epochs')
57
58 parser.add_argument('--log_file',
59                     type = str, default = 'cnn-svrt.log',
60                     help = 'Log file name')
61
62 args = parser.parse_args()
63
64 ######################################################################
65
66 log_file = open(args.log_file, 'w')
67
68 print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
69
70 def log_string(s):
71     s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + \
72         str(problem_number) + ' ' + s
73     log_file.write(s + '\n')
74     log_file.flush()
75     print(s)
76
77 ######################################################################
78
79 def generate_set(p, n):
80     target = torch.LongTensor(n).bernoulli_(0.5)
81     t = time.time()
82     input = svrt.generate_vignettes(p, target)
83     t = time.time() - t
84     log_string('DATA_SET_GENERATION {:.02f} sample/s'.format(n / t))
85     input = input.view(input.size(0), 1, input.size(1), input.size(2)).float()
86     return Variable(input), Variable(target)
87
88 ######################################################################
89
90 # Afroze's ShallowNet
91
92 #                    map size   nb. maps
93 #                  ----------------------
94 #                    128x128    1
95 # -- conv(21x21)  -> 108x108    6
96 # -- max(2x2)     -> 54x54      6
97 # -- conv(19x19)  -> 36x36      16
98 # -- max(2x2)     -> 18x18      16
99 # -- conv(18x18)  -> 1x1        120
100 # -- reshape      -> 120        1
101 # -- full(120x84) -> 84         1
102 # -- full(84x2)   -> 2          1
103
104 class Net(nn.Module):
105     def __init__(self):
106         super(Net, self).__init__()
107         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
108         self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
109         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
110         self.fc1 = nn.Linear(120, 84)
111         self.fc2 = nn.Linear(84, 2)
112
113     def forward(self, x):
114         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
115         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
116         x = fn.relu(self.conv3(x))
117         x = x.view(-1, 120)
118         x = fn.relu(self.fc1(x))
119         x = self.fc2(x)
120         return x
121
122 def train_model(train_input, train_target):
123     model, criterion = Net(), nn.CrossEntropyLoss()
124
125     if torch.cuda.is_available():
126         model.cuda()
127         criterion.cuda()
128
129     optimizer, bs = optim.SGD(model.parameters(), lr = 1e-2), 100
130
131     for k in range(0, args.nb_epochs):
132         acc_loss = 0.0
133         for b in range(0, train_input.size(0), bs):
134             output = model.forward(train_input.narrow(0, b, bs))
135             loss = criterion(output, train_target.narrow(0, b, bs))
136             acc_loss = acc_loss + loss.data[0]
137             model.zero_grad()
138             loss.backward()
139             optimizer.step()
140         log_string('TRAIN_LOSS {:d} {:f}'.format(k, acc_loss))
141
142     return model
143
144 ######################################################################
145
146 def nb_errors(model, data_input, data_target, bs = 100):
147     ne = 0
148
149     for b in range(0, data_input.size(0), bs):
150         output = model.forward(data_input.narrow(0, b, bs))
151         wta_prediction = output.data.max(1)[1].view(-1)
152
153         for i in range(0, bs):
154             if wta_prediction[i] != data_target.narrow(0, b, bs).data[i]:
155                 ne = ne + 1
156
157     return ne
158
159 ######################################################################
160
161 for problem_number in range(1, 24):
162     train_input, train_target = generate_set(problem_number, args.nb_train_samples)
163     test_input, test_target = generate_set(problem_number, args.nb_test_samples)
164
165     if torch.cuda.is_available():
166         train_input, train_target = train_input.cuda(), train_target.cuda()
167         test_input, test_target = test_input.cuda(), test_target.cuda()
168
169     mu, std = train_input.data.mean(), train_input.data.std()
170     train_input.data.sub_(mu).div_(std)
171     test_input.data.sub_(mu).div_(std)
172
173     model = train_model(train_input, train_target)
174
175     nb_train_errors = nb_errors(model, train_input, train_target)
176
177     log_string('TRAIN_ERROR {:.02f}% {:d} {:d}'.format(
178         100 * nb_train_errors / train_input.size(0),
179         nb_train_errors,
180         train_input.size(0))
181     )
182
183     nb_test_errors = nb_errors(model, test_input, test_target)
184
185     log_string('TEST_ERROR {:.02f}% {:d} {:d}'.format(
186         100 * nb_test_errors / test_input.size(0),
187         nb_test_errors,
188         test_input.size(0))
189     )
190
191 ######################################################################