From d6f73f1d5093fb098e822e14db382dd3a1c63a2a Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 18 Jul 2023 22:35:59 +0200 Subject: [PATCH] Update. --- main.py | 33 ++++++++++++++++++++++++++++++++- tasks.py | 2 +- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index e3fd9f0..0d4930d 100755 --- a/main.py +++ b/main.py @@ -82,6 +82,17 @@ parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth") ############################## # picoclvr options +parser.add_argument("--sandbox_level", type=int, default=0) + +parser.add_argument("--sandbox_levels_nb_items", type=int, default=25) + +parser.add_argument("--sandbox_levels_len_source", type=int, default=5) + +parser.add_argument("--sandbox_levels_len_result", type=int, default=8) + +############################## +# picoclvr options + parser.add_argument("--picoclvr_nb_colors", type=int, default=5) parser.add_argument("--picoclvr_height", type=int, default=12) @@ -265,8 +276,28 @@ picoclvr_pruner_eval = ( ###################################################################### if args.task == "sandbox": + if args.sandbox_level == 0: + problem = tasks.ProblemLevel0( + nb_sentences=args.sandbox_levels_nb_items, + len_prompt=args.sandbox_levels_len_source, + len_result=args.sandbox_levels_len_result, + ) + elif args.sandbox_level == 1: + problem = tasks.ProblemLevel1( + nb_operators=args.sandbox_levels_nb_items, + len_source=args.sandbox_levels_len_source, + len_result=args.sandbox_levels_len_result, + ) + elif args.sandbox_level == 2: + problem = tasks.ProblemLevel2( + len_source=args.sandbox_levels_len_source, + len_result=args.sandbox_levels_len_result, + ) + else: + raise ValueError(f"Unknown sandbox level {args.sandbox_level}") + task = tasks.SandBox( - tasks.ProblemLevel2(), + problem, # tasks.ProblemAddition(zero_padded=False, inverted_result=False), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, diff --git a/tasks.py b/tasks.py index 73f61bf..e7c2f75 100755 --- a/tasks.py +++ b/tasks.py @@ -76,7 +76,7 @@ class Problem: class ProblemLevel0(Problem): def __init__(self, nb_sentences=100, len_prompt=5, len_result=5): - self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result)) + self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result)) self.seq[:, len_prompt] = 10 def generate_sequences(self, nb): -- 2.20.1