Update.
[picoclvr.git] / main.py
diff --git a/main.py b/main.py
index 5506bbd..1b0d39a 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -81,15 +81,15 @@ parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
 ##############################
 # rpl options
 
-parser.add_argument("--rpl-nb_starting_values", type=int, default=5)
+parser.add_argument("--rpl_nb_starting_values", type=int, default=5)
 
-parser.add_argument("--rpl-max_input", type=int, default=9)
+parser.add_argument("--rpl_max_input", type=int, default=9)
 
-parser.add_argument("--rpl-prog_len", type=int, default=10)
+parser.add_argument("--rpl_prog_len", type=int, default=10)
 
-parser.add_argument("--rpl-nb_runs", type=int, default=8)
+parser.add_argument("--rpl_nb_runs", type=int, default=8)
 
-parser.add_argument("--rpl-no-prog", action="store_true", default=False)
+parser.add_argument("--rpl_no_prog", action="store_true", default=False)
 
 ##############################
 # sandbox options
@@ -518,12 +518,12 @@ else:
 
 if args.task == "expr" and args.expr_input_file is not None:
     task.produce_results(
-        nb_epochs_finished,
-        model,
-        args.result_dir,
-        log_string,
-        args.deterministic_synthesis,
-        args.expr_input_file,
+        n_epoch=nb_epochs_finished,
+        model=model,
+        result_dir=args.result_dir,
+        logger=log_string,
+        deterministic_synthesis=args.deterministic_synthesis,
+        input_file=args.expr_input_file,
     )
 
     exit(0)
@@ -599,11 +599,11 @@ nb_samples_seen = 0
 
 if nb_epochs_finished >= nb_epochs:
     task.produce_results(
-        nb_epochs_finished,
-        model,
-        args.result_dir,
-        log_string,
-        args.deterministic_synthesis,
+        n_epoch=nb_epochs_finished,
+        model=model,
+        result_dir=args.result_dir,
+        logger=log_string,
+        deterministic_synthesis=args.deterministic_synthesis,
     )
 
 for n_epoch in range(nb_epochs_finished, nb_epochs):
@@ -657,7 +657,11 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
         )
 
         task.produce_results(
-            n_epoch, model, args.result_dir, log_string, args.deterministic_synthesis
+            n_epoch=n_epoch,
+            model=model,
+            result_dir=args.result_dir,
+            logger=log_string,
+            deterministic_synthesis=args.deterministic_synthesis,
         )
 
     checkpoint = {