Added elapsed time in logging.
[pysvrt.git] / vignette_set.py
1
2 #  svrt is the ``Synthetic Visual Reasoning Test'', an image
3 #  generator for evaluating classification performance of machine
4 #  learning systems, humans and primates.
5 #
6 #  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
7 #  Written by Francois Fleuret <francois.fleuret@idiap.ch>
8 #
9 #  This file is part of svrt.
10 #
11 #  svrt is free software: you can redistribute it and/or modify it
12 #  under the terms of the GNU General Public License version 3 as
13 #  published by the Free Software Foundation.
14 #
15 #  svrt is distributed in the hope that it will be useful, but
16 #  WITHOUT ANY WARRANTY; without even the implied warranty of
17 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18 #  General Public License for more details.
19 #
20 #  You should have received a copy of the GNU General Public License
21 #  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
22
23 import torch
24 from math import sqrt
25 from torch import multiprocessing
26
27 from torch import Tensor
28 from torch.autograd import Variable
29
30 import svrt
31
32 ######################################################################
33
def generate_one_batch(s):
    """Generate one batch of vignettes together with random binary labels.

    Args:
        s: A sequence of exactly three items,
            (problem_number, batch_size, random_seed). The arguments are
            packed in a single object so that the function can be mapped
            over a list of argument sets.

    Returns:
        A two-element list [input, target]. `target` is a LongTensor of
        `batch_size` labels drawn from a Bernoulli(0.5). `input` is the
        output of svrt.generate_vignettes converted to float and viewed
        with an extra singleton dimension inserted at index 1.
    """
    problem_number, batch_size, random_seed = s

    # Seed the svrt generator so that the batch depends on the seed.
    svrt.seed(random_seed)

    labels = torch.LongTensor(batch_size).bernoulli_(0.5)
    vignettes = svrt.generate_vignettes(problem_number, labels)

    # The generated tensor is indexed along three dimensions;
    # presumably (sample, row, column) -- TODO confirm against svrt.
    nb, height, width = vignettes.size(0), vignettes.size(1), vignettes.size(2)
    vignettes = vignettes.float().view(nb, 1, height, width)

    # A list (not a tuple) because callers replace its elements in place.
    return [ vignettes, labels ]
41
class VignetteSet:
    """Set of vignette batches for one problem, fully held in memory.

    All the batches are generated when the object is built, normalized
    with statistics computed over the whole set, and optionally moved
    to the GPU.
    """

    def __init__(self, problem_number, nb_batches, batch_size, cuda = False):
        """Generate and normalize `nb_batches` batches of `batch_size` samples.

        Args:
            problem_number: Problem identifier forwarded to
                generate_one_batch.
            nb_batches: Number of batches to generate.
            batch_size: Number of samples per batch.
            cuda: If true, every input and target tensor is moved to
                the GPU after normalization.
        """
        self.cuda = cuda
        self.batch_size = batch_size
        self.problem_number = problem_number
        self.nb_batches = nb_batches
        self.nb_samples = self.nb_batches * self.batch_size

        # One random seed per batch, all drawn before any generation.
        seeds = torch.LongTensor(self.nb_batches).random_()

        # NOTE: the batches are generated sequentially. A variant using
        # multiprocessing.Pool(multiprocessing.cpu_count()) and
        # pool.map(generate_one_batch, ...) was disabled because of
        # unexplained behavior of the multi-processing.
        self.data = [
            generate_one_batch([ problem_number, batch_size, seeds[k] ])
            for k in range(self.nb_batches)
        ]

        # Accumulate the per-batch means of the values and of their
        # squares.
        sum_means = 0.0
        sum_sq_means = 0.0
        for batch_input, _ in self.data:
            nb_values = batch_input.numel()
            sum_means += batch_input.sum() / nb_values
            sum_sq_means += batch_input.pow(2).sum() / nb_values

        # std = sqrt(E[x^2] - E[x]^2), with the expectations estimated
        # as averages of the per-batch means.
        mean = sum_means / self.nb_batches
        variance = sum_sq_means / self.nb_batches - mean * mean
        std = sqrt(variance)

        for batch in self.data:
            # Normalize the input in place.
            batch[0].sub_(mean).div_(std)
            if cuda:
                batch[0] = batch[0].cuda()
                batch[1] = batch[1].cuda()

    def get_batch(self, b):
        """Return the stored [input, target] pair of batch number `b`."""
        return self.data[b]
82
83 ######################################################################
84
class CompressedVignetteSet:
    """Set of vignette batches for one problem, stored compressed.

    The batches are generated when the object is built, but only their
    compressed storages are kept. Each batch is uncompressed, converted
    to float and normalized every time it is requested with get_batch.
    """

    def __init__(self, problem_number, nb_batches, batch_size, cuda = False):
        """Generate `nb_batches` batches and keep them in compressed form.

        Args:
            problem_number: Problem identifier forwarded to
                svrt.generate_vignettes.
            nb_batches: Number of batches to generate.
            batch_size: Number of samples per batch.
            cuda: If true, get_batch returns tensors moved to the GPU.
        """
        self.cuda = cuda
        self.batch_size = batch_size
        self.problem_number = problem_number
        self.nb_batches = nb_batches
        self.nb_samples = self.nb_batches * self.batch_size
        self.targets = []
        self.input_storages = []

        # Vignette height and width, recorded from the generated data
        # so that get_batch does not rely on a hard-coded image size.
        self.height = None
        self.width = None

        acc = 0.0
        acc_sq = 0.0
        for b in range(0, self.nb_batches):
            target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
            vignettes = svrt.generate_vignettes(problem_number, target)

            if self.height is None:
                # The generated tensor is indexed along three
                # dimensions; presumably (sample, row, column) --
                # TODO confirm against svrt.
                self.height = vignettes.size(1)
                self.width = vignettes.size(2)

            # Convert to float once and reuse it for both statistics.
            values = vignettes.float()
            acc += values.sum() / values.numel()
            acc_sq += values.pow(2).sum() / values.numel()

            self.targets.append(target)
            self.input_storages.append(svrt.compress(vignettes.storage()))

        # std = sqrt(E[x^2] - E[x]^2), with the expectations estimated
        # as averages of the per-batch means.
        self.mean = acc / self.nb_batches
        self.std = sqrt(acc_sq / self.nb_batches - self.mean * self.mean)

    def get_batch(self, b):
        """Uncompress, normalize and return batch number `b`.

        Returns:
            A tuple (input, target). `input` is a float tensor viewed as
            (batch_size, 1, height, width) and normalized with the mean
            and standard deviation of the whole set. Both tensors are
            moved to the GPU if the set was built with cuda enabled.
        """
        values = torch.ByteTensor(svrt.uncompress(self.input_storages[b])).float()
        values = values.view(self.batch_size, 1, self.height, self.width)
        values = values.sub_(self.mean).div_(self.std)
        target = self.targets[b]

        if self.cuda:
            values = values.cuda()
            target = target.cuda()

        return values, target
116
117         return input, target
118
119 ######################################################################