-def generate_set(p, n):
- target = torch.LongTensor(n).bernoulli_(0.5)
- input = svrt.generate_vignettes(p, target)
- input = input.view(input.size(0), 1, input.size(1), input.size(2)).float()
- return Variable(input), Variable(target)
+parser = argparse.ArgumentParser(
+ description = 'Simple convnet test on the SVRT.',
+ formatter_class = argparse.ArgumentDefaultsHelpFormatter
+)
+
+parser.add_argument('--nb_train_batches',
+ type = int, default = 1000,
+ help = 'How many samples for train')
+
+parser.add_argument('--nb_test_batches',
+ type = int, default = 100,
+ help = 'How many samples for test')
+
+parser.add_argument('--nb_epochs',
+ type = int, default = 50,
+ help = 'How many training epochs')
+
+parser.add_argument('--batch_size',
+ type = int, default = 100,
+ help = 'Mini-batch size')
+
+parser.add_argument('--log_file',
+ type = str, default = 'cnn-svrt.log',
+ help = 'Log file name')
+
+parser.add_argument('--compress_vignettes',
+ action='store_true', default = False,
+ help = 'Use lossless compression to reduce the memory footprint')
+
+args = parser.parse_args()