Update.
[mygptrnn.git] / tasks.py
index 4777a11..727b196 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -1473,6 +1473,8 @@ class Grid(Task):
         nb_test_samples,
         batch_size,
         size,
+        nb_shapes,
+        nb_colors,
         logger=None,
         device=torch.device("cpu"),
     ):
@@ -1480,7 +1482,9 @@ class Grid(Task):
 
         self.device = device
         self.batch_size = batch_size
-        self.grid_factory = grid.GridFactory(size=size)
+        self.grid_factory = grid.GridFactory(
+            size=size, nb_shapes=nb_shapes, nb_colors=nb_colors
+        )
 
         if logger is not None:
             logger(