Added svrt.seed(long).
[pysvrt.git] / vignette_set.py
index ea52159..695fed3 100755 (executable)
@@ -31,7 +31,8 @@ import svrt
 ######################################################################
 
 class VignetteSet:
-    def __init__(self, problem_number, nb_batches, batch_size):
+    def __init__(self, problem_number, nb_batches, batch_size, cuda = False):
+        self.cuda = cuda
         self.batch_size = batch_size
         self.problem_number = problem_number
         self.nb_batches = nb_batches
@@ -46,7 +47,7 @@ class VignetteSet:
             target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
             input = svrt.generate_vignettes(problem_number, target)
             input = input.float().view(input.size(0), 1, input.size(1), input.size(2))
-            if torch.cuda.is_available():
+            if self.cuda:
                 input = input.cuda()
                 target = target.cuda()
             acc += input.sum() / input.numel()
@@ -65,7 +66,8 @@ class VignetteSet:
 ######################################################################
 
 class CompressedVignetteSet:
-    def __init__(self, problem_number, nb_batches, batch_size):
+    def __init__(self, problem_number, nb_batches, batch_size, cuda = False):
+        self.cuda = cuda
         self.batch_size = batch_size
         self.problem_number = problem_number
         self.nb_batches = nb_batches
@@ -84,14 +86,14 @@ class CompressedVignetteSet:
             self.input_storages.append(svrt.compress(input.storage()))
 
         self.mean = acc / self.nb_batches
-        self.std = math.sqrt(acc_sq / self.nb_batches - self.mean * self.mean)
+        self.std = sqrt(acc_sq / self.nb_batches - self.mean * self.mean)
 
     def get_batch(self, b):
         input = torch.ByteTensor(svrt.uncompress(self.input_storages[b])).float()
         input = input.view(self.batch_size, 1, 128, 128).sub_(self.mean).div_(self.std)
         target = self.targets[b]
 
-        if torch.cuda.is_available():
+        if self.cuda:
             input = input.cuda()
             target = target.cuda()