From 59600257e0eda86816a43676c5ffbe598d78bdb5 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 25 Jul 2023 08:58:33 -1000 Subject: [PATCH] Update. --- README.txt | 3 +++ main.py | 2 +- problems.py | 12 ++++++++++-- tasks.py | 8 ++++++-- 4 files changed, 20 insertions(+), 5 deletions(-) diff --git a/README.txt b/README.txt index a4cd46b..d4740f3 100644 --- a/README.txt +++ b/README.txt @@ -18,3 +18,6 @@ For the arithmetic expressions experiments ./main.py --task=expr --nb_blocks=48 --dim_model=1024 --nb_train_samples=2500000 --result_dir=results_expr_48b_d1024_2.5M ====================================================================== +25.07.2023 + +./main.py --task=sandbox --nb_train_samples=10000 --nb_test_samples=1000 --nb_blocks=4 --nb_heads=1 --nb_epochs=20 diff --git a/main.py b/main.py index 9c28e47..ed4adf5 100755 --- a/main.py +++ b/main.py @@ -366,7 +366,7 @@ if args.task == "sandbox": # problem, # problems.ProblemAddition(zero_padded=False, inverted_result=False), # problems.ProblemLenId(len_max=args.sandbox_levels_len_source), - problems.ProblemTwoTargets(len_total=12, len_targets=4), + problems.ProblemTwoTargets(len_total=16, len_targets=4), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, diff --git a/problems.py b/problems.py index aa3acf0..2e0ca36 100755 --- a/problems.py +++ b/problems.py @@ -47,14 +47,22 @@ class ProblemTwoTargets(Problem): a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :]) a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :]) sequences = torch.cat( - (s, torch.full((nb, 1), 12), a1, torch.full((nb, 1), 12), a2), 1 + ( + s, + torch.full((nb, 1), 12), + a1, + torch.full((nb, 1), 12), + a2, + torch.full((nb, 1), 12), + ), + 1, ) ar_mask = (sequences == 12).long() ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) return sequences, ar_mask def seq2str(self, seq): - return "".join("0123456789+-|"[x.item()] for x in seq) + return "".join("0123456789-+|"[x.item()] for x in seq) #################### diff --git a/tasks.py b/tasks.py index 0143ab2..cc3aea0 100755 --- a/tasks.py +++ b/tasks.py @@ -181,7 +181,9 @@ class SandBox(Task): f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" ) - if save_attention_image is not None: + if save_attention_image is None: + logger("no save_attention_image (is pycairo installed?)") + else: for k in range(10): ns = torch.randint(self.test_input.size(0), (1,)).item() input = self.test_input[ns : ns + 1].clone() @@ -1167,7 +1169,9 @@ class RPL(Task): f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%" ) - if save_attention_image is not None: + if save_attention_image is None: + logger("no save_attention_image (is pycairo installed?)") + else: ns = torch.randint(self.test_input.size(0), (1,)).item() input = self.test_input[ns : ns + 1].clone() last = (input != self.t_nul).max(0).values.nonzero().max() + 3 -- 2.20.1