Minor update.
[pysvrt.git] / generate.py
1 #!/usr/bin/env python
2
3 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
4 #  generator for evaluating classification performance of machine
5 #  learning systems, humans and primates.
6 #
7 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 #
10 #  This file is part of svrt.
11 #
12 #  svrt is free software: you can redistribute it and/or modify it
13 #  under the terms of the GNU General Public License version 3 as
14 #  published by the Free Software Foundation.
15 #
16 #  svrt is distributed in the hope that it will be useful, but
17 #  WITHOUT ANY WARRANTY; without even the implied warranty of
18 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 #  General Public License for more details.
20 #
21 #  You should have received a copy of the GNU General Public License
22 #  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
23
24 import time
25 import argparse
26
27 import torch
28 import torchvision, os
29
30 from torch import optim
31 from torch import FloatTensor as Tensor
32 from torch.autograd import Variable
33 from torch import nn
34 from torch.nn import functional as fn
35
36 from torchvision import datasets, transforms, utils
37
38 import svrt
39
40 ######################################################################
41 # Parsing arguments
42 ######################################################################
43
44 parser = argparse.ArgumentParser(
45     description='SVRT sample generator.',
46     formatter_class = argparse.ArgumentDefaultsHelpFormatter
47 )
48
49 parser.add_argument('--nb_samples',
50                     type = int,
51                     default = 1000,
52                     help='How many samples to generate in total')
53
54 parser.add_argument('--batch_size',
55                     type = int,
56                     default = 1000,
57                     help='How many samples to generate at once')
58
59 parser.add_argument('--problem',
60                     type = int,
61                     default = 1,
62                     help='Problem to generate samples from')
63
64 parser.add_argument('--data_dir',
65                     type = str,
66                     default = '',
67                     help='Where to generate the samples')
68
69 ######################################################################
70
71 args = parser.parse_args()
72
73 if os.path.isdir(args.data_dir):
74     name = 'problem_{:02d}/class_'.format(args.problem)
75     os.makedirs(args.data_dir + '/' + name + '0', exist_ok = True)
76     os.makedirs(args.data_dir + '/' + name + '1', exist_ok = True)
77 else:
78     raise FileNotFoundError('Cannot find ' + args.data_dir)
79
80 for n in range(0, args.nb_samples, args.batch_size):
81     print(n, '/', args.nb_samples)
82     labels = torch.LongTensor(min(args.batch_size, args.nb_samples - n)).zero_()
83     labels.narrow(0, 0, labels.size(0)//2).fill_(1)
84     x = svrt.generate_vignettes(args.problem, labels).float()
85     x.sub_(128).div_(64)
86     for k in range(x.size(0)):
87         filename = args.data_dir + '/problem_{:02d}/class_{:d}/img_{:07d}.png'.format(args.problem, labels[k], k + n)
88         torchvision.utils.save_image(x[k].view(1, x.size(1), x.size(2)), filename)