Update.
author: François Fleuret <francois@fleuret.org>
Tue, 25 Jul 2023 18:58:33 +0000 (08:58 -1000)
committer: François Fleuret <francois@fleuret.org>
Tue, 25 Jul 2023 18:58:33 +0000 (08:58 -1000)
README.txt
main.py
problems.py
tasks.py

index a4cd46b..d4740f3 100644 (file)
--- a/README.txt
+++ b/README.txt
@@ -18,3 +18,6 @@ For the arithmetic expressions experiments
 
 ./main.py --task=expr --nb_blocks=48 --dim_model=1024 --nb_train_samples=2500000 --result_dir=results_expr_48b_d1024_2.5M
 ======================================================================
+25.07.2023
+
+./main.py --task=sandbox --nb_train_samples=10000 --nb_test_samples=1000 --nb_blocks=4 --nb_heads=1 --nb_epochs=20
diff --git a/main.py b/main.py
index 9c28e47..ed4adf5 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -366,7 +366,7 @@ if args.task == "sandbox":
         # problem,
         # problems.ProblemAddition(zero_padded=False, inverted_result=False),
         # problems.ProblemLenId(len_max=args.sandbox_levels_len_source),
-        problems.ProblemTwoTargets(len_total=12, len_targets=4),
+        problems.ProblemTwoTargets(len_total=16, len_targets=4),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
         batch_size=args.batch_size,
index aa3acf0..2e0ca36 100755 (executable)
--- a/problems.py
+++ b/problems.py
@@ -47,14 +47,22 @@ class ProblemTwoTargets(Problem):
         a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
         a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
         sequences = torch.cat(
-            (s, torch.full((nb, 1), 12), a1, torch.full((nb, 1), 12), a2), 1
+            (
+                s,
+                torch.full((nb, 1), 12),
+                a1,
+                torch.full((nb, 1), 12),
+                a2,
+                torch.full((nb, 1), 12),
+            ),
+            1,
         )
         ar_mask = (sequences == 12).long()
         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
         return sequences, ar_mask
 
     def seq2str(self, seq):
-        return "".join("0123456789+-|"[x.item()] for x in seq)
+        return "".join("0123456789-+|"[x.item()] for x in seq)
 
 
 ####################
index 0143ab2..cc3aea0 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -181,7 +181,9 @@ class SandBox(Task):
             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
         )
 
-        if save_attention_image is not None:
+        if save_attention_image is None:
+            logger("no save_attention_image (is pycairo installed?)")
+        else:
             for k in range(10):
                 ns = torch.randint(self.test_input.size(0), (1,)).item()
                 input = self.test_input[ns : ns + 1].clone()
@@ -1167,7 +1169,9 @@ class RPL(Task):
             f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
         )
 
-        if save_attention_image is not None:
+        if save_attention_image is None:
+            logger("no save_attention_image (is pycairo installed?)")
+        else:
             ns = torch.randint(self.test_input.size(0), (1,)).item()
             input = self.test_input[ns : ns + 1].clone()
             last = (input != self.t_nul).max(0).values.nonzero().max() + 3