Update: rework ProblemTwoTargets around explicit target delimiters, use it for the sandbox task, save attention images for ten random test samples, and tweak label placement in save_attention_image.
author     François Fleuret <francois@fleuret.org>
           Tue, 25 Jul 2023 18:22:39 +0000 (08:22 -1000)
committer  François Fleuret <francois@fleuret.org>
           Tue, 25 Jul 2023 18:22:39 +0000 (08:22 -1000)
graph.py
main.py
problems.py
tasks.py

diff --git a/graph.py b/graph.py
index 2c7caf8..6db9ed7 100755 (executable)
--- a/graph.py
+++ b/graph.py
@@ -110,7 +110,7 @@ def save_attention_image(
             x_advance,
             y_advance,
         ) = ctx.text_extents(s)
-        ctx.move_to(k * token_gap - width_t / 2, token_gap / 5 - y_bearing)
+        ctx.move_to(k * token_gap - width_t / 2, 2 * token_gap / 5)
         ctx.show_text(s)
 
     for k, t in enumerate(tokens_output):
@@ -146,7 +146,7 @@ def save_attention_image(
 if __name__ == "__main__":
     import mygpt
 
-    tokens_output = ["<wat>", 2, 3, 4, "<end>"]
+    tokens_output = ["<wat>", "-", 3, 4, "<end>"]
     tokens_input = [""] + tokens_output[:-1]
 
     vocabulary_size = 3
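
For reference, a minimal standalone pycairo sketch (not part of the commit; file name and sizes are illustrative) of the label placement used in save_attention_image after this change: each token string is centered horizontally from its text_extents() width, and its baseline now sits at the fixed offset 2 * token_gap / 5 instead of being derived from y_bearing.

    import cairo

    token_gap = 12
    tokens_output = ["<wat>", "-", 3, 4, "<end>"]

    # Illustrative surface; the real code draws into the attention figure.
    surface = cairo.PDFSurface("token_labels_demo.pdf", 120, 40)
    ctx = cairo.Context(surface)
    ctx.set_font_size(4)
    ctx.translate(10, 10)  # keep the first label on the page

    for k, t in enumerate(tokens_output):
        s = str(t)
        (x_bearing, y_bearing, width_t, height_t, x_advance, y_advance) = ctx.text_extents(s)
        # center horizontally, put the baseline at a fixed fraction of the gap
        ctx.move_to(k * token_gap - width_t / 2, 2 * token_gap / 5)
        ctx.show_text(s)

    surface.finish()
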
diff --git a/main.py b/main.py
index 68b946a..9c28e47 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -365,7 +365,8 @@ if args.task == "sandbox":
     task = tasks.SandBox(
         # problem,
         # problems.ProblemAddition(zero_padded=False, inverted_result=False),
-        problems.ProblemLenId(len_max=args.sandbox_levels_len_source),
+        # problems.ProblemLenId(len_max=args.sandbox_levels_len_source),
+        problems.ProblemTwoTargets(len_total=12, len_targets=4),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
         batch_size=args.batch_size,
diff --git a/problems.py b/problems.py
index 7b1d698..aa3acf0 100755 (executable)
--- a/problems.py
+++ b/problems.py
@@ -22,32 +22,39 @@ class Problem:
 
 
 class ProblemTwoTargets(Problem):
-    def __init__(self, len_total=10, len_target=2):
-        assert len_total >= 3 * (2 + len_target) - 1
+    def __init__(self, len_total=10, len_targets=3):
+        assert len_targets >= 3
+        assert len_total >= 3 * len_targets - 1
         self.len_total = len_total
-        self.len_target = len_target
+        self.len_targets = len_targets
 
     def generate_sequences(self, nb):
         k = torch.arange(self.len_total)[None, :]
-        l = torch.randint(self.len_total, (2, nb))[:, :, None] + 1
-        i = torch.randint(10, (2, nb))[:, :, None]
-        a = l[0]
-        b = l[0] + 1 + l[1]
-        c = l[0] + 1 + l[1] + 1 + l[0]
-        sequences = (
-            (k < a) * i[0]
-            + (k == a) * 10
-            + (k > a) * (k < b) * i[1]
-            + (k == b) * 11
-            + (k > b) * (k < c) * i[1]
-            + (k >= c) * 12
+        s = torch.randint(10, (nb, self.len_total))
+        l = torch.rand(nb, self.len_total)
+        l = l * (k <= self.len_total - self.len_targets).long()
+        k1 = l.argmax(dim=1, keepdim=True)
+        m = (k != k1).long() * (k != k1 + self.len_targets - 1).long()
+        s = s * m + 10 * (1 - m)
+        l = l * (
+            1
+            - (k + self.len_targets - 1 >= k1).long()
+            * (k < k1 + self.len_targets).long()
         )
-        ar_mask = (sequences == 11).long()
+        k2 = l.argmax(dim=1, keepdim=True)
+        m = (k != k2).long() * (k != k2 + self.len_targets - 1).long()
+        s = s * m + 11 * (1 - m)
+        a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
+        a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
+        sequences = torch.cat(
+            (s, torch.full((nb, 1), 12), a1, torch.full((nb, 1), 12), a2), 1
+        )
+        ar_mask = (sequences == 12).long()
         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
         return sequences, ar_mask
 
     def seq2str(self, seq):
-        return "".join("0123456789|>_"[x.item()] for x in seq)
+        return "".join("0123456789+-|"[x.item()] for x in seq)
 
 
 ####################
@@ -212,18 +219,8 @@ class ProblemAddition(Problem):
         return "".join(self.id2char[x.item()] for x in seq)
 
 
-# class ProblemUnion(Problem):
-# problems = [ProblemByheart()]
-# nb_common_codes = 100
-
-# def generate_sequences(nb_samples):
-# problem_indexes = torch.randint(len(problems), (nb_samples,))
-# nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0)
-# print(f"{nb_samples_per_problem}")
-# all_seq = []
-# for nb, p in zip(nb_samples_per_problem, problems):
-# all_seq.append(p.generate_sequences(nb_samples_per_problem[nb]))
-# return all_seq
-
-# for strain, stest in zip(train_seq, test_seq):
-# s = torch.cat((strain, stest), 0)
+if __name__ == "__main__":
+    p = ProblemTwoTargets(12, 4)
+    s, m = p.generate_sequences(10)
+    for x in s:
+        print(p.seq2str(x))
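
A minimal sketch of what the reworked ProblemTwoTargets generates, assuming the token-to-character mapping "0123456789+-|" from seq2str above and that problems.py from this commit is importable: a len_total-long digit string with the first target bracketed by '+' and the second by '-', followed by '|', the interior of the first target, '|', and the interior of the second. ar_mask is 1 exactly on the tokens after the first '|', i.e. the part the model has to produce.

    from problems import ProblemTwoTargets

    p = ProblemTwoTargets(len_total=12, len_targets=4)
    sequences, ar_mask = p.generate_sequences(5)
    for s, m in zip(sequences, ar_mask):
        # e.g. "4+57+01-83-9|57|83", with the mask covering "57|83"
        print(p.seq2str(s), "->", p.seq2str(s[m == 1]))
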
diff --git a/tasks.py b/tasks.py
index 038a8ac..0143ab2 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -182,36 +182,37 @@ class SandBox(Task):
         )
 
         if save_attention_image is not None:
-            ns = torch.randint(self.test_input.size(0), (1,)).item()
-            input = self.test_input[ns : ns + 1].clone()
-
-            with torch.autograd.no_grad():
-                t = model.training
-                model.eval()
-                model.record_attention(True)
-                model(BracketedSequence(input))
-                model.train(t)
-                ram = model.retrieve_attention()
-                model.record_attention(False)
-
-            tokens_output = [c for c in self.problem.seq2str(input[0])]
-            tokens_input = ["n/a"] + tokens_output[:-1]
-            for n_head in range(ram[0].size(1)):
-                filename = os.path.join(
-                    result_dir, f"rpl_attention_{n_epoch}_h{n_head}.pdf"
-                )
-                attention_matrices = [m[0, n_head] for m in ram]
-                save_attention_image(
-                    filename,
-                    tokens_input,
-                    tokens_output,
-                    attention_matrices,
-                    k_top=10,
-                    # min_total_attention=0.9,
-                    token_gap=12,
-                    layer_gap=50,
-                )
-                logger(f"wrote {filename}")
+            for k in range(10):
+                ns = torch.randint(self.test_input.size(0), (1,)).item()
+                input = self.test_input[ns : ns + 1].clone()
+
+                with torch.autograd.no_grad():
+                    t = model.training
+                    model.eval()
+                    model.record_attention(True)
+                    model(BracketedSequence(input))
+                    model.train(t)
+                    ram = model.retrieve_attention()
+                    model.record_attention(False)
+
+                tokens_output = [c for c in self.problem.seq2str(input[0])]
+                tokens_input = ["n/a"] + tokens_output[:-1]
+                for n_head in range(ram[0].size(1)):
+                    filename = os.path.join(
+                        result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
+                    )
+                    attention_matrices = [m[0, n_head] for m in ram]
+                    save_attention_image(
+                        filename,
+                        tokens_input,
+                        tokens_output,
+                        attention_matrices,
+                        k_top=10,
+                        # min_total_attention=0.9,
+                        token_gap=12,
+                        layer_gap=50,
+                    )
+                    logger(f"wrote {filename}")
 
 
 ######################################################################
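
The new loop above repeats the same record-and-restore pattern for ten random test samples. Below is a minimal sketch of that pattern as a standalone helper, assuming a model that exposes record_attention()/retrieve_attention() as mygpt.MyGPT does, and that BracketedSequence is importable from mygpt; torch.no_grad() is the usual spelling of the torch.autograd.no_grad() used in the hunk.

    import torch
    from mygpt import BracketedSequence

    def capture_attention(model, input):
        # One forward pass with attention recording enabled; the model's
        # training mode is restored before returning the recorded tensors,
        # one per layer, indexed as [sample, head] to get a single matrix.
        with torch.no_grad():
            was_training = model.training
            model.eval()
            model.record_attention(True)
            model(BracketedSequence(input))
            model.train(was_training)
            ram = model.retrieve_attention()
            model.record_attention(False)
        return ram

    # Usage, mirroring the loop above:
    # ram = capture_attention(model, input)
    # attention_matrices = [m[0, n_head] for m in ram]
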