Update.
[picoclvr.git] / tasks.py
index a53d213..08aa8ca 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -1475,6 +1475,7 @@ class Grid(Task):
         nb_test_samples,
         batch_size,
         size,
+        fraction_play=0.0,
         logger=None,
         device=torch.device("cpu"),
     ):
@@ -1490,10 +1491,12 @@ class Grid(Task):
             )
 
         self.train_descr = self.grid_factory.generate_samples(
-            nb_train_samples, lambda r: tqdm.tqdm(r)
+            nb=nb_train_samples,
+            fraction_play=fraction_play,
+            progress_bar=lambda r: tqdm.tqdm(r),
         )
         self.test_descr = self.grid_factory.generate_samples(
-            nb_test_samples, lambda r: tqdm.tqdm(r)
+            nb=nb_test_samples, fraction_play=0.0, progress_bar=lambda r: tqdm.tqdm(r)
         )
 
         # Build the tokenizer