Update.
[mygptrnn.git] / tasks.py
index 58638ed..727b196 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -58,7 +58,7 @@ def masked_inplace_autoregression(
 
 
 class Task:
-    def batches(self, split="train"):
+    def batches(self, split="train", desc=None):
         pass
 
     def vocabulary_size(self):
@@ -328,7 +328,7 @@ class PicoCLVR(Task):
         self.train_input = self.tensorize(self.train_descr)
         self.test_input = self.tensorize(self.test_descr)
 
-    def batches(self, split="train"):
+    def batches(self, split="train", desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
         for batch in tqdm.tqdm(
@@ -1473,6 +1473,8 @@ class Grid(Task):
         nb_test_samples,
         batch_size,
         size,
+        nb_shapes,
+        nb_colors,
         logger=None,
         device=torch.device("cpu"),
     ):
@@ -1480,7 +1482,9 @@ class Grid(Task):
 
         self.device = device
         self.batch_size = batch_size
-        self.grid_factory = grid.GridFactory(size=size)
+        self.grid_factory = grid.GridFactory(
+            size=size, nb_shapes=nb_shapes, nb_colors=nb_colors
+        )
 
         if logger is not None:
             logger(
@@ -1515,11 +1519,13 @@ class Grid(Task):
         self.train_input = self.str2tensor(self.train_descr)
         self.test_input = self.str2tensor(self.test_descr)
 
-    def batches(self, split="train"):
+    def batches(self, split="train", desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
+        if desc is None:
+            desc = f"epoch-{split}"
         for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
         ):
             yield self.trim(batch)
 
@@ -1618,11 +1624,13 @@ class QMLP(Task):
 
         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
 
-    def batches(self, split="train"):
+    def batches(self, split="train", desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
+        if desc is None:
+            desc = f"epoch-{split}"
         for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
         ):
             yield batch