3 # svrt is the ``Synthetic Visual Reasoning Test'', an image
4 # generator for evaluating classification performance of machine
5 # learning systems, humans and primates.
7 # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 # Written by Francois Fleuret <francois.fleuret@idiap.ch>
10 # This file is part of svrt.
12 # svrt is free software: you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License version 3 as
14 # published by the Free Software Foundation.
16 # svrt is distributed in the hope that it will be useful, but
17 # WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 # General Public License for more details.
21 # You should have received a copy of the GNU General Public License
# along with svrt. If not, see <http://www.gnu.org/licenses/>.
29 from colorama import Fore, Back, Style
35 from torch import optim
36 from torch import FloatTensor as Tensor
37 from torch.autograd import Variable
39 from torch.nn import functional as fn
40 from torchvision import datasets, transforms, utils
46 ######################################################################
def _parse_bool(s):
    """Parse a truth string to 0/1, like ``distutils.util.strtobool``.

    ``distutils`` was deprecated by PEP 632 and removed in Python 3.12,
    so this is a drop-in replacement accepting the same spellings and
    returning the same int result (argparse also applies it to the
    string defaults below, exactly as strtobool was).
    """
    v = s.strip().lower()
    if v in ('y', 'yes', 't', 'true', 'on', '1'):
        return 1
    if v in ('n', 'no', 'f', 'false', 'off', '0'):
        return 0
    raise ValueError('invalid truth value %r' % (s,))

parser = argparse.ArgumentParser(
    description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--nb_train_samples',
                    type = int, default = 100000)

parser.add_argument('--nb_test_samples',
                    type = int, default = 10000)

parser.add_argument('--nb_epochs',
                    type = int, default = 50)

parser.add_argument('--batch_size',
                    type = int, default = 100)

parser.add_argument('--log_file',
                    type = str, default = 'default.log')

parser.add_argument('--compress_vignettes',
                    type = _parse_bool, default = 'True',
                    help = 'Use lossless compression to reduce the memory footprint')

parser.add_argument('--deep_model',
                    type = _parse_bool, default = 'True',
                    help = 'Use Afroze\'s Alexnet-like deep model')

parser.add_argument('--test_loaded_models',
                    type = _parse_bool, default = 'False',
                    help = 'Should we compute the test errors of loaded models')

args = parser.parse_args()
82 ######################################################################
log_file = open(args.log_file, 'w')

# Timestamp of the previous log_string call, used to print the elapsed
# time between consecutive log lines ('start' on the first call).
pred_log_t = None

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)

# Log and prints the string, with a time stamp. Does not log the
# remark (the remark is shown on the console only).
def log_string(s, remark = ''):
    global pred_log_t

    t = time.time()

    if pred_log_t is None:
        elapsed = 'start'
    else:
        elapsed = '+{:.02f}s'.format(t - pred_log_t)

    pred_log_t = t

    log_file.write('[' + time.ctime() + '] ' + elapsed + ' ' + s + '\n')
    # Flush so a crash mid-run still leaves a usable log.
    log_file.flush()

    print(Fore.BLUE + '[' + time.ctime() + '] ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
108 ######################################################################
110 # Afroze's ShallowNet
113 # ----------------------
115 # -- conv(21x21 x 6) -> 108x108 6
116 # -- max(2x2) -> 54x54 6
117 # -- conv(19x19 x 16) -> 36x36 16
118 # -- max(2x2) -> 18x18 16
119 # -- conv(18x18 x 120) -> 1x1 120
120 # -- reshape -> 120 1
121 # -- full(120x84) -> 84 1
122 # -- full(84x2) -> 2 1
class AfrozeShallowNet(nn.Module):
    """Afroze's shallow LeNet-style classifier for SVRT vignettes.

    For a 1x128x128 input, the conv/pool stages reduce the image to a
    120x1x1 map (see the layer-size comment above), which two linear
    layers map to the 2 class logits.
    """

    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)
        # Used elsewhere in the script to build the saved-model filename.
        self.name = 'shallownet'

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        # Flatten the 120x1x1 feature map for the linear layers.
        x = x.view(-1, 120)
        x = fn.relu(self.fc1(x))
        x = self.fc2(x)
        return x
143 ######################################################################
148 # ----------------------
# (sizes below assume the 1x128x128 SVRT vignette input)
# -- conv(7x7 x 32, stride=4, pad=3) -> 32x32  32
# -- max(2x2)                        -> 16x16  32
# -- conv(5x5 x 96, pad=2)           -> 16x16  96
# -- max(2x2)                        -> 8x8    96
# -- conv(3x3 x 128, pad=1)          -> 8x8   128
# -- conv(3x3 x 128, pad=1)          -> 8x8   128
# -- conv(3x3 x 96, pad=1)           -> 8x8    96
# -- max(2x2)                        -> 4x4    96
# -- reshape                         -> 1536    1
# -- full(1536x256)                  -> 256     1
# -- full(256x256)                   -> 256     1
# -- full(256x2)                     -> 2       1
class AfrozeDeepNet(nn.Module):
    """Afroze's deeper AlexNet-like classifier for SVRT vignettes.

    For a 1x128x128 input the conv/pool stack yields a 96x4x4 map
    (= 1536 features), followed by three linear layers down to the
    2 class logits.
    """

    def __init__(self):
        super(AfrozeDeepNet, self).__init__()
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(1536, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)
        # Used elsewhere in the script to build the saved-model filename.
        self.name = 'deepnet'

    def forward(self, x):
        x = self.conv1(x)
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.relu(x)

        x = self.conv2(x)
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.relu(x)

        x = self.conv3(x)
        x = fn.relu(x)

        x = self.conv4(x)
        x = fn.relu(x)

        x = self.conv5(x)
        x = fn.max_pool2d(x, kernel_size=2)
        x = fn.relu(x)

        # Flatten the 96x4x4 map for the fully-connected classifier.
        x = x.view(-1, 1536)

        x = fn.relu(self.fc1(x))
        x = fn.relu(self.fc2(x))
        x = self.fc3(x)

        return x
207 ######################################################################
def train_model(model, train_set):
    """Train `model` on `train_set` with SGD and cross-entropy loss.

    Runs args.nb_epochs epochs over the set's batches, logging the
    accumulated loss and an ETA after each epoch. Returns the model.
    The visible original omitted the zero_grad/backward/step sequence,
    so no learning could occur; restored here.
    """
    batch_size = args.batch_size
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        criterion.cuda()

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    start_t = time.time()

    for e in range(0, args.nb_epochs):
        acc_loss = 0
        for b in range(0, train_set.nb_batches):
            input, target = train_set.get_batch(b)
            output = model.forward(Variable(input))
            loss = criterion(output, Variable(target))
            # NOTE(review): `loss.data[0]` is the PyTorch <= 0.3 idiom
            # (matching the Variable usage in this file); on >= 0.4 it
            # would be `loss.item()`.
            acc_loss = acc_loss + loss.data[0]
            model.zero_grad()
            loss.backward()
            optimizer.step()
        # Mean epoch duration so far, used to extrapolate an ETA.
        dt = (time.time() - start_t) / (e + 1)
        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                   ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')

    return model
236 ######################################################################
def nb_errors(model, data_set):
    """Return the number of misclassified samples of `model` on `data_set`.

    Iterates the set batch by batch, takes the winner-take-all class
    (index of the max output) of each row, and counts mismatches with
    the target labels. The visible original lacked the accumulator and
    return, so it always produced None; restored here.
    """
    ne = 0
    for b in range(0, data_set.nb_batches):
        input, target = data_set.get_batch(b)
        output = model.forward(Variable(input))
        # Predicted class = index of the maximum logit per row.
        wta_prediction = output.data.max(1)[1].view(-1)

        for i in range(0, data_set.batch_size):
            if wta_prediction[i] != target[i]:
                ne = ne + 1

    return ne
251 ######################################################################
# Record the full command-line configuration in the log, one line per
# argument, for reproducibility of the run.
for name in vars(args):
    log_string('argument ' + str(name) + ' ' + str(getattr(args, name)))
256 ######################################################################
def int_to_suffix(n):
    """Format a sample count compactly: 2000000 -> '2M', 100000 -> '100K'.

    Falls back to the plain decimal string when n is not a round
    multiple of 1000 (the visible original had no fallback and returned
    None in that case).
    """
    if n >= 1000000 and n%1000000 == 0:
        return str(n//1000000) + 'M'
    elif n >= 1000 and n%1000 == 0:
        return str(n//1000) + 'K'
    else:
        return str(n)
266 ######################################################################
# Both sample counts must split evenly into batches, since the vignette
# sets are generated and consumed one full batch at a time. The visible
# original only printed the complaint and kept running; abort instead.
if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
    print('The number of samples must be a multiple of the batch size.')
    raise ValueError('nb_train_samples and nb_test_samples must be multiples of batch_size')
# Pick the data-set implementation once; the per-problem loop below
# instantiates it for train and test. (The visible original lacked the
# else:, which would have unconditionally overwritten the choice.)
if args.compress_vignettes:
    log_string('using_compressed_vignettes')
    VignetteSet = vignette_set.CompressedVignetteSet
else:
    log_string('using_uncompressed_vignettes')
    VignetteSet = vignette_set.VignetteSet
# Main loop: for each of the 23 SVRT problems, build (or reload) a
# model, train it if no saved parameters exist, and report error rates.
for problem_number in range(1, 24):

    log_string('############### problem ' + str(problem_number) + ' ###############')

    # Select the architecture requested on the command line.
    if args.deep_model:
        model = AfrozeDeepNet()
    else:
        model = AfrozeShallowNet()

    if torch.cuda.is_available(): model.cuda()

    model_filename = model.name + '_pb:' + \
                     str(problem_number) + '_ns:' + \
                     int_to_suffix(args.nb_train_samples) + '.param'

    nb_parameters = 0
    for p in model.parameters(): nb_parameters += p.numel()
    log_string('nb_parameters {:d}'.format(nb_parameters))

    ##################################################
    # Tries to load the model

    need_to_train = False
    try:
        model.load_state_dict(torch.load(model_filename))
        log_string('loaded_model ' + model_filename)
    except (OSError, RuntimeError):
        # No saved parameters (or incompatible ones): train from scratch.
        need_to_train = True

    ##################################################
    # Train if necessary

    if need_to_train:

        log_string('training_model ' + model_filename)

        t = time.time()

        train_set = VignetteSet(problem_number,
                                args.nb_train_samples, args.batch_size,
                                cuda = torch.cuda.is_available())

        log_string('data_generation {:0.2f} samples / s'.format(
            train_set.nb_samples / (time.time() - t))
        )

        train_model(model, train_set)
        torch.save(model.state_dict(), model_filename)
        log_string('saved_model ' + model_filename)

        nb_train_errors = nb_errors(model, train_set)

        log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_train_errors / train_set.nb_samples,
            nb_train_errors,
            train_set.nb_samples)
        )

    ##################################################
    # Test if we just trained it, or were asked to re-test loaded models

    if need_to_train or args.test_loaded_models:

        t = time.time()

        test_set = VignetteSet(problem_number,
                               args.nb_test_samples, args.batch_size,
                               cuda = torch.cuda.is_available())

        log_string('data_generation {:0.2f} samples / s'.format(
            test_set.nb_samples / (time.time() - t))
        )

        nb_test_errors = nb_errors(model, test_set)

        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_test_errors / test_set.nb_samples,
            nb_test_errors,
            test_set.nb_samples)
        )
362 ######################################################################