Trying to make multiprocessing and cuda friends with each other.
authorFrancois Fleuret <francois@fleuret.org>
Fri, 16 Jun 2017 08:39:11 +0000 (10:39 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Fri, 16 Jun 2017 08:39:11 +0000 (10:39 +0200)
vignette_set.py

index c46beea..19a6f33 100755 (executable)
@@ -22,7 +22,7 @@
 
 import torch
 from math import sqrt
-from multiprocessing import Pool, cpu_count
+from torch.multiprocessing import Pool, cpu_count
 
 from torch import Tensor
 from torch.autograd import Variable
@@ -32,14 +32,11 @@ import svrt
 ######################################################################
 
 def generate_one_batch(s):
-    problem_number, batch_size, cuda, random_seed = s
+    problem_number, batch_size, random_seed = s
     svrt.seed(random_seed)
     target = torch.LongTensor(batch_size).bernoulli_(0.5)
     input = svrt.generate_vignettes(problem_number, target)
     input = input.float().view(input.size(0), 1, input.size(1), input.size(2))
-    if cuda:
-        input = input.cuda()
-        target = target.cuda()
     return [ input, target ]
 
 class VignetteSet:
@@ -54,7 +51,7 @@ class VignetteSet:
         seeds = torch.LongTensor(self.nb_batches).random_()
         mp_args = []
         for b in range(0, self.nb_batches):
-            mp_args.append( [ problem_number, batch_size, cuda, seeds[b] ])
+            mp_args.append( [ problem_number, batch_size, seeds[b] ])
 
         # self.data = []
         # for b in range(0, self.nb_batches):
@@ -73,6 +70,9 @@ class VignetteSet:
         std = sqrt(acc_sq / self.nb_batches - mean * mean)
         for b in range(0, self.nb_batches):
             self.data[b][0].sub_(mean).div_(std)
+            if cuda:
+                self.data[b][0] = self.data[b][0].cuda()
+                self.data[b][1] = self.data[b][1].cuda()
 
     def get_batch(self, b):
         return self.data[b]