import torch
from math import sqrt
-from multiprocessing import Pool, cpu_count
+from torch import multiprocessing
from torch import Tensor
from torch.autograd import Variable
######################################################################
def generate_one_batch(s):
- problem_number, batch_size, cuda, random_seed = s
+ problem_number, batch_size, random_seed = s
svrt.seed(random_seed)
target = torch.LongTensor(batch_size).bernoulli_(0.5)
input = svrt.generate_vignettes(problem_number, target)
input = input.float().view(input.size(0), 1, input.size(1), input.size(2))
- if cuda:
- input = input.cuda()
- target = target.cuda()
return [ input, target ]
class VignetteSet:
seeds = torch.LongTensor(self.nb_batches).random_()
mp_args = []
for b in range(0, self.nb_batches):
- mp_args.append( [ problem_number, batch_size, cuda, seeds[b] ])
+ mp_args.append( [ problem_number, batch_size, seeds[b] ])
- # self.data = []
- # for b in range(0, self.nb_batches):
- # self.data.append(generate_one_batch(mp_args[b]))
+ self.data = []
+ for b in range(0, self.nb_batches):
+ self.data.append(generate_one_batch(mp_args[b]))
+
+ # Weird thing going on with the multi-processing, waiting for more info
- self.data = Pool(cpu_count()).map(generate_one_batch, mp_args)
+ # pool = multiprocessing.Pool(multiprocessing.cpu_count())
+ # self.data = pool.map(generate_one_batch, mp_args)
acc = 0.0
acc_sq = 0.0
std = sqrt(acc_sq / self.nb_batches - mean * mean)
for b in range(0, self.nb_batches):
self.data[b][0].sub_(mean).div_(std)
+ if cuda:
+ self.data[b][0] = self.data[b][0].cuda()
+ self.data[b][1] = self.data[b][1].cuda()
def get_batch(self, b):
return self.data[b]