Made VignetteSet.__init__ multi-proc.
authorFrancois Fleuret <francois@fleuret.org>
Fri, 16 Jun 2017 07:43:52 +0000 (09:43 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Fri, 16 Jun 2017 07:43:52 +0000 (09:43 +0200)
vignette_set.py

index 695fed3..72880ba 100755 (executable)
@@ -22,6 +22,7 @@
 
 import torch
 from math import sqrt
+from multiprocessing import Pool, cpu_count
 
 from torch import Tensor
 from torch.autograd import Variable
@@ -30,38 +31,47 @@ import svrt
 
 ######################################################################
 
+def generate_one_batch(s):
+    svrt.seed(s)
+    target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
+    input = svrt.generate_vignettes(problem_number, target)
+    input = input.float().view(input.size(0), 1, input.size(1), input.size(2))
+    if self.cuda:
+        input = input.cuda()
+        target = target.cuda()
+    return [ input, target ]
+
 class VignetteSet:
+
     def __init__(self, problem_number, nb_batches, batch_size, cuda = False):
         self.cuda = cuda
         self.batch_size = batch_size
         self.problem_number = problem_number
         self.nb_batches = nb_batches
         self.nb_samples = self.nb_batches * self.batch_size
-        self.targets = []
-        self.inputs = []
+
+        seed_list = torch.LongTensor(self.nb_batches).random_().tolist()
+
+        # self.data = []
+        # for b in range(0, self.nb_batches):
+            # self.data.append(generate_one_batch(seed_list[b]))
+
+        self.data = Pool(cpu_count()).map(generate_one_batch, seed_list)
 
         acc = 0.0
         acc_sq = 0.0
-
         for b in range(0, self.nb_batches):
-            target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
-            input = svrt.generate_vignettes(problem_number, target)
-            input = input.float().view(input.size(0), 1, input.size(1), input.size(2))
-            if self.cuda:
-                input = input.cuda()
-                target = target.cuda()
+            input = self.data[b][0]
             acc += input.sum() / input.numel()
             acc_sq += input.pow(2).sum() /  input.numel()
-            self.targets.append(target)
-            self.inputs.append(input)
 
         mean = acc / self.nb_batches
         std = sqrt(acc_sq / self.nb_batches - mean * mean)
         for b in range(0, self.nb_batches):
-            self.inputs[b].sub_(mean).div_(std)
+            self.data[b][0].sub_(mean).div_(std)
 
     def get_batch(self, b):
-        return self.inputs[b], self.targets[b]
+        return self.data[b]
 
 ######################################################################