Update.
[picoclvr.git] / main.py
diff --git a/main.py b/main.py
index 55f2c2f..9198edc 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -89,7 +89,9 @@ parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
 ##############################
 # filetask
 
-parser.add_argument("--filetask_file", type=str, default=None)
+parser.add_argument("--filetask_train_file", type=str, default=None)
+
+parser.add_argument("--filetask_test_file", type=str, default=None)
 
 ##############################
 # rpl options
@@ -403,10 +405,11 @@ picoclvr_pruner_eval = (
 
 if args.task == "file":
     assert (
-        args.filetask_file is not None
-    ), "You have to specify the task file with --filetask_file <filename>"
+        args.filetask_train_file is not None and args.filetask_test_file is not None
+    ), "You have to specify the task train and test files"
     task = tasks.TaskFromFile(
-        args.filetask_file,
+        args.filetask_train_file,
+        args.filetask_test_file,
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
         batch_size=args.batch_size,