Added a script to general samples as images in a pytorch-structured ImageFolder.
authorFrancois Fleuret <francois@fleuret.org>
Tue, 9 Jan 2018 14:50:23 +0000 (15:50 +0100)
committerFrancois Fleuret <francois@fleuret.org>
Tue, 9 Jan 2018 14:50:23 +0000 (15:50 +0100)
generate.py [new file with mode: 0755]

diff --git a/generate.py b/generate.py
new file mode 100755 (executable)
index 0000000..7f29683
--- /dev/null
@@ -0,0 +1,85 @@
+#!/usr/bin/env python
+
+#  svrt is the ``Synthetic Visual Reasoning Test'', an image
+#  generator for evaluating classification performance of machine
+#  learning systems, humans and primates.
+#
+#  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
+#  Written by Francois Fleuret <francois.fleuret@idiap.ch>
+#
+#  This file is part of svrt.
+#
+#  svrt is free software: you can redistribute it and/or modify it
+#  under the terms of the GNU General Public License version 3 as
+#  published by the Free Software Foundation.
+#
+#  svrt is distributed in the hope that it will be useful, but
+#  WITHOUT ANY WARRANTY; without even the implied warranty of
+#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+#  General Public License for more details.
+#
+#  You should have received a copy of the GNU General Public License
+#  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
+
+import time
+import argparse
+
+import torch
+import torchvision, os
+
+from torch import optim
+from torch import FloatTensor as Tensor
+from torch.autograd import Variable
+from torch import nn
+from torch.nn import functional as fn
+
+from torchvision import datasets, transforms, utils
+
+import svrt
+
+######################################################################
+# Parsing arguments
+######################################################################
+
+parser = argparse.ArgumentParser(
+    description='SVRT sample generator.',
+    formatter_class = argparse.ArgumentDefaultsHelpFormatter
+)
+
+parser.add_argument('--nb_samples',
+                    type = int,
+                    default = 1000,
+                    help='How many samples to generate')
+
+parser.add_argument('--problem',
+                    type = int,
+                    default = 1,
+                    help='Problem to generate samples from')
+
+parser.add_argument('--data_dir',
+                    type = str,
+                    default = '',
+                    help='Where to generate the samples')
+
+######################################################################
+
+args = parser.parse_args()
+
+if os.path.isdir(args.data_dir):
+    name = 'problem_{:02d}/class_'.format(args.problem)
+    os.makedirs(args.data_dir + '/' + name + '0', exist_ok = True)
+    os.makedirs(args.data_dir + '/' + name + '1', exist_ok = True)
+else:
+    raise FileNotFoundError('Cannot find ' + args.data_dir)
+
+labels = torch.LongTensor(args.nb_samples).zero_()
+labels.narrow(0, 0, labels.size(0)//2).fill_(1)
+x = svrt.generate_vignettes(args.problem, labels).float()
+
+x.sub_(128).div_(64)
+
+print('MEAN', x.mean(), 'STD', x.std())
+
+for k in range(x.size(0)):
+    filename = args.data_dir + '/problem_{:02d}/class_{:d}/img_{:06d}.png'.format(args.problem, labels[k], k)
+    torchvision.utils.save_image(x[k].view(1, x.size(1), x.size(2)), filename)