# svrt is the "Synthetic Visual Reasoning Test", an image
# generator for evaluating classification performance of machine
# learning systems, humans and primates.

# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
# Written by Francois Fleuret <francois.fleuret@idiap.ch>

# This file is part of svrt.

# svrt is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.

# svrt is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with svrt. If not, see <http://www.gnu.org/licenses/>.
import time
import argparse
import distutils.util
import re
import signal

from colorama import Fore, Back, Style

import torch
import torchvision

from torch import optim, multiprocessing, nn
from torch import FloatTensor as Tensor
from torch.autograd import Variable
from torch.nn import functional as fn

from torchvision import datasets, transforms, utils

import svrtset
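# svrtset is the companion module from this repository; it provides
# the VignetteSet and CompressedVignetteSet classes used below to
# generate and hold the problem vignettes.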
######################################################################

parser = argparse.ArgumentParser(
    description = "Convolutional networks for the SVRT. Written by Francois Fleuret, (C) Idiap research institute.",
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--nb_train_samples',
                    type = int, default = 100000)

parser.add_argument('--nb_test_samples',
                    type = int, default = 10000)

parser.add_argument('--nb_validation_samples',
                    type = int, default = 10000)

parser.add_argument('--validation_error_threshold',
                    type = float, default = 0.0,
                    help = 'Early training termination criterion')

parser.add_argument('--nb_epochs',
                    type = int, default = 50)

parser.add_argument('--batch_size',
                    type = int, default = 100)

parser.add_argument('--log_file',
                    type = str, default = 'default.log')

parser.add_argument('--nb_exemplar_vignettes',
                    type = int, default = 32)

parser.add_argument('--compress_vignettes',
                    type = distutils.util.strtobool, default = 'True',
                    help = 'Use lossless compression to reduce the memory footprint')

parser.add_argument('--save_test_mistakes',
                    type = distutils.util.strtobool, default = 'False')

parser.add_argument('--model',
                    type = str, default = 'deepnet',
                    help = 'What model to use')

parser.add_argument('--test_loaded_models',
                    type = distutils.util.strtobool, default = 'False',
                    help = 'Should we compute the test errors of loaded models')

parser.add_argument('--problems',
                    type = str, default = '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23',
                    help = 'What problems to process')

args = parser.parse_args()
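# For illustration, a typical invocation (assuming this file is saved
# as cnn-svrt.py; the flag values are arbitrary examples) could be:
#
#   python cnn-svrt.py --model deepnet2 --problems 1,5,20 \
#                      --nb_train_samples 10000 --nb_epochs 10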
######################################################################

log_file = open(args.log_file, 'a')

log_file.write('@@@@@@@@@@@@@@@@@@@ ' + time.ctime() + ' @@@@@@@@@@@@@@@@@@@\n')

pred_log_t = None
last_tag_t = time.time()

print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
# Logs and prints the string, with a time stamp. Does not log the
# remark.

def log_string(s, remark = ''):
    global pred_log_t, last_tag_t
    t = time.time()

    if pred_log_t is None:
        elapsed = 'start'
    else:
        elapsed = '+{:.02f}s'.format(t - pred_log_t)
    pred_log_t = t

    # Print a red time tag at most once per hour
    if t > last_tag_t + 3600:
        last_tag_t = t
        print(Fore.RED + time.ctime() + Style.RESET_ALL)

    log_file.write(re.sub(' ', '_', time.ctime()) + ' ' + elapsed + ' ' + s + '\n')
    log_file.flush()

    print(Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL
          + ' ' + s + Fore.CYAN + remark + Style.RESET_ALL)
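# A line produced by log_string looks, for example, like:
#
#   Mon_Jun_12_14:03:52_2017 +1.37s train_loss 3 152.234567
#
# i.e. the date with spaces replaced by underscores, the time elapsed
# since the previous entry, then the message (the remark is printed
# to the terminal but not logged).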
######################################################################

def handler_sigint(signum, frame):
    log_string('got sigint')
    # Terminate cleanly after logging the signal
    exit(0)

def handler_sigterm(signum, frame):
    log_string('got sigterm')
    exit(0)

signal.signal(signal.SIGINT, handler_sigint)
signal.signal(signal.SIGTERM, handler_sigterm)
######################################################################

# Afroze's ShallowNet

#                        map size   nb. maps
#                      ----------------------
#    input                128x128     1
# -- conv(21x21 x 6)   -> 108x108     6
# -- max(2x2)          -> 54x54       6
# -- conv(19x19 x 16)  -> 36x36      16
# -- max(2x2)          -> 18x18      16
# -- conv(18x18 x 120) -> 1x1       120
# -- reshape           -> 120         1
# -- full(120x84)      -> 84          1
# -- full(84x2)        -> 2           1
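# The map sizes above follow from valid convolutions and 2x2
# max-poolings on a 128x128 input: 128-21+1 = 108, 108/2 = 54,
# 54-19+1 = 36, 36/2 = 18, and 18-18+1 = 1.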
class AfrozeShallowNet(nn.Module):

    # Used for --model selection and for naming checkpoint files
    name = 'shallownet'

    def __init__(self):
        super(AfrozeShallowNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=19)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 2)

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        x = x.view(-1, 120)
        x = fn.relu(self.fc1(x))
        return self.fc2(x)
######################################################################

class AfrozeDeepNet(nn.Module):

    name = 'deepnet'

    def __init__(self):
        super(AfrozeDeepNet, self).__init__()
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32,  96, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d( 96, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128,  96, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(1536, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        x = fn.relu(self.conv4(x))
        x = fn.relu(fn.max_pool2d(self.conv5(x), kernel_size=2))
        x = x.view(-1, 1536)
        x = fn.relu(self.fc1(x))
        x = fn.relu(self.fc2(x))
        return self.fc3(x)
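# Size bookkeeping for the fully connected layers: the stride-4 conv1
# takes the 128x128 input to 32x32, and the three 2x2 max-poolings
# halve it to 16x16, 8x8 and finally 4x4. With 96 output maps in
# conv5 this gives 96 * 4 * 4 = 1536 inputs to fc1; the same
# reasoning yields 16 * 512 for DeepNet2 and 128 * 4 * 4 = 2048 for
# DeepNet3 below.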
######################################################################

class DeepNet2(nn.Module):

    name = 'deepnet2'

    def __init__(self):
        super(DeepNet2, self).__init__()
        self.nb_channels = 512
        self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d(32, self.nb_channels, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(self.nb_channels, self.nb_channels, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(self.nb_channels, self.nb_channels, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(self.nb_channels, self.nb_channels, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(16 * self.nb_channels, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 2)

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        x = fn.relu(self.conv4(x))
        x = fn.relu(fn.max_pool2d(self.conv5(x), kernel_size=2))
        x = x.view(-1, 16 * self.nb_channels)
        x = fn.relu(self.fc1(x))
        x = fn.relu(self.fc2(x))
        return self.fc3(x)
######################################################################

class DeepNet3(nn.Module):

    name = 'deepnet3'

    def __init__(self):
        super(DeepNet3, self).__init__()
        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
        self.conv2 = nn.Conv2d( 32, 128, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv7 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(2048, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    def forward(self, x):
        x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
        x = fn.relu(fn.max_pool2d(self.conv2(x), kernel_size=2))
        x = fn.relu(self.conv3(x))
        x = fn.relu(self.conv4(x))
        x = fn.relu(fn.max_pool2d(self.conv5(x), kernel_size=2))
        x = fn.relu(self.conv6(x))
        x = fn.relu(self.conv7(x))
        x = x.view(-1, 2048)
        x = fn.relu(self.fc1(x))
        x = fn.relu(self.fc2(x))
        return self.fc3(x)
######################################################################

def nb_errors(model, data_set, mistake_filename_pattern = None):
    ne = 0
    for b in range(0, data_set.nb_batches):
        input, target = data_set.get_batch(b)
        output = model.forward(Variable(input))
        wta_prediction = output.data.max(1)[1].view(-1)

        for i in range(0, data_set.batch_size):
            if wta_prediction[i] != target[i]:
                ne = ne + 1
                if mistake_filename_pattern is not None:
                    img = input[i].clone()
                    # Normalize to [0, 1] before saving
                    img.sub_(img.min())
                    img.div_(img.max())
                    k = b * data_set.batch_size + i
                    filename = mistake_filename_pattern.format(k, target[i])
                    torchvision.utils.save_image(img, filename)
                    print(Fore.RED + 'Wrote ' + filename + Style.RESET_ALL)

    return ne
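# For example, with the pattern 'mistake_{:06d}_{:d}.png' used in the
# test loop below, a misclassified sample of global index 42 and true
# class 1 is written to 'mistake_000042_1.png'.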
######################################################################

def train_model(model, model_filename, train_set, validation_set, nb_epochs_done = 0):
    batch_size = args.batch_size
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        criterion.cuda()

    optimizer = optim.SGD(model.parameters(), lr = 1e-2)

    start_t = time.time()

    for e in range(nb_epochs_done, args.nb_epochs):
        acc_loss = 0.0
        for b in range(0, train_set.nb_batches):
            input, target = train_set.get_batch(b)
            output = model.forward(Variable(input))
            loss = criterion(output, Variable(target))
            acc_loss = acc_loss + loss.data[0]
            model.zero_grad()
            loss.backward()
            optimizer.step()

        dt = (time.time() - start_t) / (e + 1)

        log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                   ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')

        torch.save([ model.state_dict(), e + 1 ], model_filename)

        if validation_set is not None:
            nb_validation_errors = nb_errors(model, validation_set)

            log_string('validation_error {:.02f}% {:d} {:d}'.format(
                100 * nb_validation_errors / validation_set.nb_samples,
                nb_validation_errors,
                validation_set.nb_samples)
            )

            if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
                log_string('below validation_error_threshold')
                break

    return model
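# With --validation_error_threshold 0.01, for instance, training
# stops as soon as the validation error reaches 1% or less; with the
# default of 0.0 no validation set is built and every model trains
# for the full --nb_epochs.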
######################################################################

for arg in vars(args):
    log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
######################################################################

def int_to_suffix(n):
    if n >= 1000000 and n%1000000 == 0:
        return str(n//1000000) + 'M'
    elif n >= 1000 and n%1000 == 0:
        return str(n//1000) + 'K'
    else:
        return str(n)
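# E.g. int_to_suffix(100000) == '100K', int_to_suffix(2000000) == '2M',
# and int_to_suffix(1234) == '1234'.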
class vignette_logger():
    def __init__(self, delay_min = 60):
        self.start_t = time.time()
        self.last_t = self.start_t
        self.delay_min = delay_min

    def __call__(self, n, m):
        t = time.time()
        if t > self.last_t + self.delay_min:
            dt = (t - self.start_t) / m
            log_string('sample_generation {:d} / {:d}'.format(
                m,
                n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
            )
            self.last_t = t
def save_exemplar_vignettes(data_set, nb, name):
    n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)

    for k in range(0, nb):
        b = n[k] // data_set.batch_size
        m = n[k] % data_set.batch_size
        i, t = data_set.get_batch(b)
        i = i[m].float()
        # Normalize to [0, 1] before assembling the patchwork
        i.sub_(i.min())
        i.div_(i.max())
        if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2))
        patchwork[k].copy_(i)

    torchvision.utils.save_image(patchwork, name)
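# Given a 4D tensor, torchvision's save_image tiles the vignettes
# into a single grid image, so each problem produces one contact
# sheet of exemplar training samples.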
######################################################################

if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
    print('The number of samples must be a multiple of the batch size.')
    raise ValueError('invalid sample counts')
if args.compress_vignettes:
    log_string('using_compressed_vignettes')
    VignetteSet = svrtset.CompressedVignetteSet
else:
    log_string('using_uncompressed_vignettes')
    VignetteSet = svrtset.VignetteSet
########################################

model_class = None
for m in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2, DeepNet3 ]:
    if args.model == m.name:
        model_class = m
        break

if model_class is None:
    print('Unknown model ' + args.model)
    raise ValueError('unknown model')

log_string('using model class ' + model_class.name)

########################################
for problem_number in map(int, args.problems.split(',')):

    log_string('############### problem ' + str(problem_number) + ' ###############')

    model = model_class()

    if torch.cuda.is_available(): model.cuda()

    model_filename = model.name + '_pb:' + \
                     str(problem_number) + '_ns:' + \
                     int_to_suffix(args.nb_train_samples) + '.pth'
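    # E.g. the deepnet model trained on problem 4 with 100000 samples
    # is checkpointed as 'deepnet_pb:4_ns:100K.pth'.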
    nb_parameters = 0
    for p in model.parameters(): nb_parameters += p.numel()
    log_string('nb_parameters {:d}'.format(nb_parameters))

    ##################################################
    # Tries to load the model; a checkpoint stores the state dict and
    # the number of epochs already done, so training can resume

    try:
        model_state_dict, nb_epochs_done = torch.load(model_filename)
        model.load_state_dict(model_state_dict)
        log_string('loaded_model ' + model_filename)
    except FileNotFoundError:
        nb_epochs_done = 0

    ##################################################
    # Trains if necessary
    if nb_epochs_done < args.nb_epochs:

        log_string('training_model ' + model_filename)

        t = time.time()

        train_set = VignetteSet(problem_number,
                                args.nb_train_samples, args.batch_size,
                                cuda = torch.cuda.is_available(),
                                logger = vignette_logger())

        log_string('data_generation {:0.2f} samples / s'.format(
            train_set.nb_samples / (time.time() - t))
        )

        if args.nb_exemplar_vignettes > 0:
            save_exemplar_vignettes(train_set, args.nb_exemplar_vignettes,
                                    'exemplar_{:d}.png'.format(problem_number))

        if args.validation_error_threshold > 0.0:
            validation_set = VignetteSet(problem_number,
                                         args.nb_validation_samples, args.batch_size,
                                         cuda = torch.cuda.is_available(),
                                         logger = vignette_logger())
        else:
            validation_set = None

        train_model(model, model_filename,
                    train_set, validation_set,
                    nb_epochs_done = nb_epochs_done)

        log_string('saved_model ' + model_filename)

        nb_train_errors = nb_errors(model, train_set)

        log_string('train_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_train_errors / train_set.nb_samples,
            nb_train_errors,
            train_set.nb_samples)
        )
    ##################################################
    # Tests if necessary

    if nb_epochs_done < args.nb_epochs or args.test_loaded_models:

        test_set = VignetteSet(problem_number,
                               args.nb_test_samples, args.batch_size,
                               cuda = torch.cuda.is_available())

        # Honor --save_test_mistakes
        nb_test_errors = nb_errors(model, test_set,
                                   mistake_filename_pattern = 'mistake_{:06d}_{:d}.png'
                                   if args.save_test_mistakes else None)

        log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(
            problem_number,
            100 * nb_test_errors / test_set.nb_samples,
            nb_test_errors,
            test_set.nb_samples)
        )

######################################################################