Update.
[picoclvr.git] / main.py
diff --git a/main.py b/main.py
index 9679236..7cb8d4f 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -102,7 +102,7 @@ parser.add_argument("--snake_width", type=int, default=8)
 
 parser.add_argument("--snake_nb_colors", type=int, default=5)
 
-parser.add_argument("--snake_length", type=int, default=400)
+parser.add_argument("--snake_length", type=int, default=200)
 
 ######################################################################
 
@@ -143,8 +143,8 @@ default_args = {
         "batch_size": 25,
     },
     "snake": {
-        "nb_epochs": 25,
-        "batch_size": 20,
+        "nb_epochs": 5,
+        "batch_size": 25,
     },
 }
 
@@ -689,7 +689,7 @@ class TaskSnake(Task):
         self.device = device
         self.prompt_length = prompt_length
 
-        self.train_input, self.train_prior_visits = snake.generate_sequences(
+        self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
             nb_train_samples,
             height,
             width,
@@ -698,7 +698,7 @@ class TaskSnake(Task):
             prompt_length,
             self.device,
         )
-        self.test_input, self.test_prior_visits = snake.generate_sequences(
+        self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
             nb_test_samples,
             height,
             width,