Clean up + argument parsing + logging into a file.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with selector.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26
27 import torch
28
29 from torch import optim
30 from torch import FloatTensor as Tensor
31 from torch.autograd import Variable
32 from torch import nn
33 from torch.nn import functional as fn
34 from torchvision import datasets, transforms, utils
35
36 import svrt
37
38 ######################################################################
39
40 parser = argparse.ArgumentParser(
41     description = 'Simple convnet test on the SVRT.',
42     formatter_class = argparse.ArgumentDefaultsHelpFormatter
43 )
44
45 parser.add_argument('--nb_train_samples',
46                     type = int, default = 100000,
47                     help = 'How many samples for train')
48
49 parser.add_argument('--nb_test_samples',
50                     type = int, default = 10000,
51                     help = 'How many samples for test')
52
53 parser.add_argument('--nb_epochs',
54                     type = int, default = 25,
55                     help = 'How many training epochs')
56
57 args = parser.parse_args()
58
59 ######################################################################
60
61 log_file = open('cnn-svrt.log', 'w')
62
63 def log_string(s):
64     s = time.ctime() + ' ' + str(problem_number) + ' | ' + s
65     log_file.write(s + '\n')
66     log_file.flush()
67     print(s)
68
69 ######################################################################
70
71 def generate_set(p, n):
72     target = torch.LongTensor(n).bernoulli_(0.5)
73     input = svrt.generate_vignettes(p, target)
74     input = input.view(input.size(0), 1, input.size(1), input.size(2)).float()
75     return Variable(input), Variable(target)
76
77 ######################################################################
78
79 # 128x128 --conv(9)-> 120x120 --max(4)-> 30x30 --conv(6)-> 25x25 --max(5)-> 5x5
80
81 class Net(nn.Module):
82     def __init__(self):
83         super(Net, self).__init__()
84         self.conv1 = nn.Conv2d(1, 10, kernel_size=9)
85         self.conv2 = nn.Conv2d(10, 20, kernel_size=6)
86         self.fc1 = nn.Linear(500, 100)
87         self.fc2 = nn.Linear(100, 2)
88
89     def forward(self, x):
90         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=4, stride=4))
91         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=5, stride=5))
92         x = x.view(-1, 500)
93         x = fn.relu(self.fc1(x))
94         x = self.fc2(x)
95         return x
96
97 def train_model(train_input, train_target):
98     model, criterion = Net(), nn.CrossEntropyLoss()
99
100     if torch.cuda.is_available():
101         model.cuda()
102         criterion.cuda()
103
104     optimizer, bs = optim.SGD(model.parameters(), lr = 1e-1), 100
105
106     for k in range(0, args.nb_epochs):
107         acc_loss = 0.0
108         for b in range(0, train_input.size(0), bs):
109             output = model.forward(train_input.narrow(0, b, bs))
110             loss = criterion(output, train_target.narrow(0, b, bs))
111             acc_loss = acc_loss + loss.data[0]
112             model.zero_grad()
113             loss.backward()
114             optimizer.step()
115         log_string('TRAIN_LOSS {:d} {:f}'.format(k, acc_loss))
116
117     return model
118
119 ######################################################################
120
121 def nb_errors(model, data_input, data_target, bs = 100):
122     ne = 0
123
124     for b in range(0, data_input.size(0), bs):
125         output = model.forward(data_input.narrow(0, b, bs))
126         wta_prediction = output.data.max(1)[1].view(-1)
127
128         for i in range(0, bs):
129             if wta_prediction[i] != data_target.narrow(0, b, bs).data[i]:
130                 ne = ne + 1
131
132     return ne
133
134 ######################################################################
135
136 for problem_number in range(1, 24):
137     train_input, train_target = generate_set(problem_number, args.nb_train_samples)
138     test_input, test_target = generate_set(problem_number, args.nb_test_samples)
139
140     if torch.cuda.is_available():
141         train_input, train_target = train_input.cuda(), train_target.cuda()
142         test_input, test_target = test_input.cuda(), test_target.cuda()
143
144     mu, std = train_input.data.mean(), train_input.data.std()
145     train_input.data.sub_(mu).div_(std)
146     test_input.data.sub_(mu).div_(std)
147
148     model = train_model(train_input, train_target)
149
150     nb_test_errors = nb_errors(model, test_input, test_target)
151
152     log_string('TEST_ERROR {:.02f}% ({:d}/{:d})'.format(
153         100 * nb_test_errors / test_input.size(0),
154         nb_test_errors,
155         test_input.size(0))
156     )
157
158 ######################################################################