Initial commit.
[pysvrt.git] / cnn-svrt.py
1 #!/usr/bin/env python-for-pytorch
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with selector.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25
26 import torch
27
28 from torch import optim
29 from torch import FloatTensor as Tensor
30 from torch.autograd import Variable
31 from torch import nn
32 from torch.nn import functional as fn
33 from torchvision import datasets, transforms, utils
34
35 import svrt
36
37 ######################################################################
38
39 def generate_set(p, n):
40     target = torch.LongTensor(n).bernoulli_(0.5)
41     input = svrt.generate_vignettes(p, target)
42     input = input.view(input.size(0), 1, input.size(1), input.size(2)).float()
43     return Variable(input), Variable(target)
44
45 ######################################################################
46
47 # 128x128 --conv(9)-> 120x120 --max(4)-> 30x30 --conv(6)-> 25x25 --max(5)-> 5x5
48
49 class Net(nn.Module):
50     def __init__(self):
51         super(Net, self).__init__()
52         self.conv1 = nn.Conv2d(1, 10, kernel_size=9)
53         self.conv2 = nn.Conv2d(10, 20, kernel_size=6)
54         self.fc1 = nn.Linear(500, 100)
55         self.fc2 = nn.Linear(100, 2)
56
57     def forward(self, x):
58         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=4, stride=4))
59         x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=5, stride=5))
60         x = x.view(-1, 500)
61         x = fn.relu(self.fc1(x))
62         x = self.fc2(x)
63         return x
64
65 def train_model(train_input, train_target):
66     model, criterion = Net(), nn.CrossEntropyLoss()
67
68     if torch.cuda.is_available():
69         model.cuda()
70         criterion.cuda()
71
72     nb_epochs = 25
73     optimizer, bs = optim.SGD(model.parameters(), lr = 1e-1), 100
74
75     for k in range(0, nb_epochs):
76         for b in range(0, nb_train_samples, bs):
77             output = model.forward(train_input.narrow(0, b, bs))
78             loss = criterion(output, train_target.narrow(0, b, bs))
79             model.zero_grad()
80             loss.backward()
81             optimizer.step()
82
83     return model
84
85 ######################################################################
86
87 def print_test_error(model, test_input, test_target):
88     bs = 100
89     nb_test_errors = 0
90
91     for b in range(0, nb_test_samples, bs):
92         output = model.forward(test_input.narrow(0, b, bs))
93         _, wta = torch.max(output.data, 1)
94
95         for i in range(0, bs):
96             if wta[i][0] != test_target.narrow(0, b, bs).data[i]:
97                 nb_test_errors = nb_test_errors + 1
98
99     print('TEST_ERROR {:.02f}% ({:d}/{:d})'.format(
100         100 * nb_test_errors / nb_test_samples,
101         nb_test_errors,
102         nb_test_samples)
103     )
104
105 ######################################################################
106
107 nb_train_samples = 100000
108 nb_test_samples = 10000
109
110 for p in range(1, 24):
111     print('-- PROBLEM #{:d} --'.format(p))
112
113     t1 = time.time()
114     train_input, train_target = generate_set(p, nb_train_samples)
115     test_input, test_target = generate_set(p, nb_test_samples)
116     if torch.cuda.is_available():
117         train_input, train_target = train_input.cuda(), train_target.cuda()
118         test_input, test_target = test_input.cuda(), test_target.cuda()
119
120     mu, std = train_input.data.mean(), train_input.data.std()
121     train_input.data.sub_(mu).div_(std)
122     test_input.data.sub_(mu).div_(std)
123
124     t2 = time.time()
125     print('[data generation {:.02f}s]'.format(t2 - t1))
126     model = train_model(train_input, train_target)
127
128     t3 = time.time()
129     print('[train {:.02f}s]'.format(t3 - t2))
130     print_test_error(model, test_input, test_target)
131
132     t4 = time.time()
133
134     print('[test {:.02f}s]'.format(t4 - t3))
135     print()
136
137 ######################################################################