Stupid typo in the headers ...
[pysvrt.git] / vignette_set.py
index 0b6de7e..b95a1db 100755 (executable)
@@ -18,7 +18,7 @@
 #  General Public License for more details.
 #
 #  You should have received a copy of the GNU General Public License
-#  along with selector.  If not, see <http://www.gnu.org/licenses/>.
+#  along with pysvrt.  If not, see <http://www.gnu.org/licenses/>.
 
 import torch
 from math import sqrt
@@ -41,11 +41,16 @@ def generate_one_batch(s):
 
 class VignetteSet:
 
-    def __init__(self, problem_number, nb_batches, batch_size, cuda = False):
+    def __init__(self, problem_number, nb_samples, batch_size, cuda = False):
+
+        if nb_samples%batch_size > 0:
+            print('nb_samples must be a mutiple of batch_size')
+            raise
+
         self.cuda = cuda
         self.batch_size = batch_size
         self.problem_number = problem_number
-        self.nb_batches = nb_batches
+        self.nb_batches = nb_samples // batch_size
         self.nb_samples = self.nb_batches * self.batch_size
 
         seeds = torch.LongTensor(self.nb_batches).random_()
@@ -83,11 +88,16 @@ class VignetteSet:
 ######################################################################
 
 class CompressedVignetteSet:
-    def __init__(self, problem_number, nb_batches, batch_size, cuda = False):
+    def __init__(self, problem_number, nb_samples, batch_size, cuda = False):
+
+        if nb_samples%batch_size > 0:
+            print('nb_samples must be a mutiple of batch_size')
+            raise
+
         self.cuda = cuda
         self.batch_size = batch_size
         self.problem_number = problem_number
-        self.nb_batches = nb_batches
+        self.nb_batches = nb_samples // batch_size
         self.nb_samples = self.nb_batches * self.batch_size
         self.targets = []
         self.input_storages = []