X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pysvrt.git;a=blobdiff_plain;f=vignette_set.py;h=c46beea3b2a809fcd7ff49db2be0e6d0d6bd992e;hp=72880bab74b2041a013f3bb2048ca50d8e54488c;hb=abbbb61852f54e90df6ac5b5f4dcb71d06f88f49;hpb=605697b42bdf62c0d8a6715d43ab40b7446e9af2 diff --git a/vignette_set.py b/vignette_set.py index 72880ba..c46beea 100755 --- a/vignette_set.py +++ b/vignette_set.py @@ -32,11 +32,12 @@ import svrt ###################################################################### def generate_one_batch(s): - svrt.seed(s) - target = torch.LongTensor(self.batch_size).bernoulli_(0.5) + problem_number, batch_size, cuda, random_seed = s + svrt.seed(random_seed) + target = torch.LongTensor(batch_size).bernoulli_(0.5) input = svrt.generate_vignettes(problem_number, target) input = input.float().view(input.size(0), 1, input.size(1), input.size(2)) - if self.cuda: + if cuda: input = input.cuda() target = target.cuda() return [ input, target ] @@ -50,13 +51,16 @@ class VignetteSet: self.nb_batches = nb_batches self.nb_samples = self.nb_batches * self.batch_size - seed_list = torch.LongTensor(self.nb_batches).random_().tolist() + seeds = torch.LongTensor(self.nb_batches).random_() + mp_args = [] + for b in range(0, self.nb_batches): + mp_args.append( [ problem_number, batch_size, cuda, seeds[b] ]) # self.data = [] # for b in range(0, self.nb_batches): - # self.data.append(generate_one_batch(seed_list[b])) + # self.data.append(generate_one_batch(mp_args[b])) - self.data = Pool(cpu_count()).map(generate_one_batch, seed_list) + self.data = Pool(cpu_count()).map(generate_one_batch, mp_args) acc = 0.0 acc_sq = 0.0