X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=vignette_set.py;h=aef23d875fd285f50f66ca1d14f130fbc30a16d0;hb=91b12f8980a69a99fd6bbdc9b6f6a422dd8cd15a;hp=0b6de7e17d7db8fc4897f844c889acf549cda562;hpb=c80fb2d538e0ccacb2523b762888db5ddada2a6e;p=pysvrt.git
diff --git a/vignette_set.py b/vignette_set.py
index 0b6de7e..aef23d8 100755
--- a/vignette_set.py
+++ b/vignette_set.py
@@ -18,7 +18,7 @@
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
-# along with selector. If not, see .
+# along with svrt. If not, see .
import torch
from math import sqrt
@@ -41,11 +41,16 @@ def generate_one_batch(s):
class VignetteSet:
- def __init__(self, problem_number, nb_batches, batch_size, cuda = False):
+ def __init__(self, problem_number, nb_samples, batch_size, cuda = False):
+
+ if nb_samples%batch_size > 0:
+ print('nb_samples must be a mutiple of batch_size')
+ raise
+
self.cuda = cuda
self.batch_size = batch_size
self.problem_number = problem_number
- self.nb_batches = nb_batches
+ self.nb_batches = nb_samples // batch_size
self.nb_samples = self.nb_batches * self.batch_size
seeds = torch.LongTensor(self.nb_batches).random_()
@@ -83,11 +88,16 @@ class VignetteSet:
######################################################################
class CompressedVignetteSet:
- def __init__(self, problem_number, nb_batches, batch_size, cuda = False):
+ def __init__(self, problem_number, nb_samples, batch_size, cuda = False):
+
+ if nb_samples%batch_size > 0:
+ print('nb_samples must be a mutiple of batch_size')
+ raise
+
self.cuda = cuda
self.batch_size = batch_size
self.problem_number = problem_number
- self.nb_batches = nb_batches
+ self.nb_batches = nb_samples // batch_size
self.nb_samples = self.nb_batches * self.batch_size
self.targets = []
self.input_storages = []