Update.
[picoclvr.git] / main.py
diff --git a/main.py b/main.py
index 8081850..ff831f4 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -89,13 +89,13 @@ parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
 ##############################
 # rpl options
 
-parser.add_argument("--rpl_nb_starting_values", type=int, default=5)
+parser.add_argument("--rpl_nb_starting_values", type=int, default=3)
 
 parser.add_argument("--rpl_max_input", type=int, default=9)
 
-parser.add_argument("--rpl_prog_len", type=int, default=10)
+parser.add_argument("--rpl_prog_len", type=int, default=8)
 
-parser.add_argument("--rpl_nb_runs", type=int, default=8)
+parser.add_argument("--rpl_nb_runs", type=int, default=5)
 
 parser.add_argument("--rpl_no_prog", action="store_true", default=False)
 
@@ -249,10 +249,10 @@ default_task_args = {
         "nb_test_samples": 10000,
     },
     "rpl": {
-        "model": "352M",
+        "model": "122M",
         "nb_epochs": 50,
-        "batch_size": 10,
-        "nb_train_samples": 2500000,
+        "batch_size": 5,
+        "nb_train_samples": 1000000,
         "nb_test_samples": 10000,
     },
     "world": {