c1fe3acd7a34b5a341c6a261e6ee5eda1da59d4e
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with selector.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25
26 import torch
27
28 from torch import optim
29 from torch import FloatTensor as Tensor
30 from torch.autograd import Variable
31 from torch import nn
32 from torch.nn import functional as fn
33 from torchvision import datasets, transforms, utils
34
35 import svrt
36
37 ######################################################################
38
39 def generate_set(p, n):
40     target = torch.LongTensor(n).bernoulli_(0.5)
41     input = svrt.generate_vignettes(p, target)
42     input = input.view(input.size(0), 1, input.size(1), input.size(2)).float()
43     return Variable(input), Variable(target)
44
45 ######################################################################
46
47 # 128x128 --conv(9)-> 120x120 --max(4)-> 30x30 --conv(6)-> 25x25 --max(5)-> 5x5
48
49 class Net(nn.Module):
50     def __init__(self):
51         super(Net, self).__init__()
52         self.conv1 = nn.Conv2d(1, 10, kernel_size=9)
53         self.conv2 = nn.Conv2d(10, 20, kernel_size=6)
54         self.fc1 = nn.Linear(500, 100)
55         self.fc2 = nn.Linear(100, 2)
56
57     def forward(self, x):
58         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=4, stride=4))
59         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=5, stride=5))
60         x = x.view(-1, 500)
61         x = fn.relu(self.fc1(x))
62         x = self.fc2(x)
63         return x
64
65 def train_model(train_input, train_target):
66     model, criterion = Net(), nn.CrossEntropyLoss()
67
68     if torch.cuda.is_available():
69         model.cuda()
70         criterion.cuda()
71
72     nb_epochs = 25
73     optimizer, bs = optim.SGD(model.parameters(), lr = 1e-1), 100
74
75     for k in range(0, nb_epochs):
76         for b in range(0, nb_train_samples, bs):
77             output = model.forward(train_input.narrow(0, b, bs))
78             loss = criterion(output, train_target.narrow(0, b, bs))
79             model.zero_grad()
80             loss.backward()
81             optimizer.step()
82
83     return model
84
85 ######################################################################
86
87 def print_test_error(model, test_input, test_target):
88     bs = 100
89     nb_test_errors = 0
90
91     for b in range(0, nb_test_samples, bs):
92         output = model.forward(test_input.narrow(0, b, bs))
93         wta_prediction = output.data.max(1)[1].view(-1)
94
95         for i in range(0, bs):
96             if wta_prediction[i] != test_target.narrow(0, b, bs).data[i]:
97                 nb_test_errors = nb_test_errors + 1
98
99     print('TEST_ERROR {:.02f}% ({:d}/{:d})'.format(
100         100 * nb_test_errors / nb_test_samples,
101         nb_test_errors,
102         nb_test_samples)
103     )
104
105 ######################################################################
106
107 nb_train_samples = 100000
108 nb_test_samples = 10000
109
110 for p in range(1, 24):
111     print('-- PROBLEM #{:d} --'.format(p))
112
113     t1 = time.time()
114     train_input, train_target = generate_set(p, nb_train_samples)
115     test_input, test_target = generate_set(p, nb_test_samples)
116
117     if torch.cuda.is_available():
118         train_input, train_target = train_input.cuda(), train_target.cuda()
119         test_input, test_target = test_input.cuda(), test_target.cuda()
120
121     mu, std = train_input.data.mean(), train_input.data.std()
122     train_input.data.sub_(mu).div_(std)
123     test_input.data.sub_(mu).div_(std)
124
125     t2 = time.time()
126     print('[data generation {:.02f}s]'.format(t2 - t1))
127     model = train_model(train_input, train_target)
128
129     t3 = time.time()
130     print('[train {:.02f}s]'.format(t3 - t2))
131     print_test_error(model, test_input, test_target)
132
133     t4 = time.time()
134
135     print('[test {:.02f}s]'.format(t4 - t3))
136     print()
137
138 ######################################################################