From e2368847af8e2eb5d6dda88b3318b64ec8637667 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Fri, 16 Jun 2017 10:39:11 +0200 Subject: [PATCH] Trying to make multiprocessing and cuda friends with each other. --- vignette_set.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vignette_set.py b/vignette_set.py index c46beea..19a6f33 100755 --- a/vignette_set.py +++ b/vignette_set.py @@ -22,7 +22,7 @@ import torch from math import sqrt -from multiprocessing import Pool, cpu_count +from torch.multiprocessing import Pool, cpu_count from torch import Tensor from torch.autograd import Variable @@ -32,14 +32,11 @@ import svrt ###################################################################### def generate_one_batch(s): - problem_number, batch_size, cuda, random_seed = s + problem_number, batch_size, random_seed = s svrt.seed(random_seed) target = torch.LongTensor(batch_size).bernoulli_(0.5) input = svrt.generate_vignettes(problem_number, target) input = input.float().view(input.size(0), 1, input.size(1), input.size(2)) - if cuda: - input = input.cuda() - target = target.cuda() return [ input, target ] class VignetteSet: @@ -54,7 +51,7 @@ class VignetteSet: seeds = torch.LongTensor(self.nb_batches).random_() mp_args = [] for b in range(0, self.nb_batches): - mp_args.append( [ problem_number, batch_size, cuda, seeds[b] ]) + mp_args.append( [ problem_number, batch_size, seeds[b] ]) # self.data = [] # for b in range(0, self.nb_batches): @@ -73,6 +70,9 @@ class VignetteSet: std = sqrt(acc_sq / self.nb_batches - mean * mean) for b in range(0, self.nb_batches): self.data[b][0].sub_(mean).div_(std) + if cuda: + self.data[b][0] = self.data[b][0].cuda() + self.data[b][1] = self.data[b][1].cuda() def get_batch(self, b): return self.data[b] -- 2.20.1