Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 23 Jul 2023 18:29:08 +0000 (20:29 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 23 Jul 2023 18:29:08 +0000 (20:29 +0200)
main.py
problems.py [new file with mode: 0755]
tasks.py

diff --git a/main.py b/main.py
index 1b0d39a..0f1fbb5 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -12,7 +12,7 @@ from torch import nn
 from torch.nn import functional as F
 
 import ffutils
-import mygpt, tasks
+import mygpt, tasks, problems
 
 ######################################################################
 
@@ -335,19 +335,19 @@ picoclvr_pruner_eval = (
 
 if args.task == "sandbox":
     if args.sandbox_level == 0:
-        problem = tasks.ProblemLevel0(
+        problem = problems.ProblemLevel0(
             nb_sentences=args.sandbox_levels_nb_items,
             len_prompt=args.sandbox_levels_len_source,
             len_result=args.sandbox_levels_len_result,
         )
     elif args.sandbox_level == 1:
-        problem = tasks.ProblemLevel1(
+        problem = problems.ProblemLevel1(
             nb_operators=args.sandbox_levels_nb_items,
             len_source=args.sandbox_levels_len_source,
             len_result=args.sandbox_levels_len_result,
         )
     elif args.sandbox_level == 2:
-        problem = tasks.ProblemLevel2(
+        problem = problems.ProblemLevel2(
             len_source=args.sandbox_levels_len_source,
             len_result=args.sandbox_levels_len_result,
         )
@@ -356,7 +356,7 @@ if args.task == "sandbox":
 
     task = tasks.SandBox(
         problem,
-        # tasks.ProblemAddition(zero_padded=False, inverted_result=False),
+        # problems.ProblemAddition(zero_padded=False, inverted_result=False),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
         batch_size=args.batch_size,
diff --git a/problems.py b/problems.py
new file mode 100755 (executable)
index 0000000..78bb64e
--- /dev/null
@@ -0,0 +1,159 @@
+#!/usr/bin/env python
+
+import math
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+
+class Problem:
+    def generate_sequences(self, nb):
+        pass
+
+    def seq2str(self, seq):
+        return "[NOT IMPLEMENTED]"
+
+
+####################
+
+
+class ProblemLevel0(Problem):
+    def __init__(self, nb_sentences=100, len_prompt=5, len_result=5):
+        self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
+        self.seq[:, len_prompt] = 10
+
+    def generate_sequences(self, nb):
+        sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
+        ar_mask = (sequences == 10).long()
+        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
+        return sequences, ar_mask
+
+
+class ProblemLevel1(Problem):
+    def __init__(self, nb_operators=100, len_source=5, len_result=8):
+        self.len_source = len_source
+        self.len_result = len_result
+        self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
+        self.operators = F.one_hot(
+            torch.rand(nb_operators, len_result, len_source).argmax(-1),
+            num_classes=len_source,
+        )
+
+    def generate_sequences(self, nb):
+        nb_operators = torch.randint(self.operators.size(0), (nb,))
+        operators = self.operators[nb_operators]
+        nb_operators = (
+            nb_operators[:, None]
+            // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
+        ) % 10
+        marker1 = torch.full((nb, 1), 10)
+        # source = torch.randint(10, (nb, self.len_source))
+        source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
+        marker2 = torch.full((nb, 1), 11)
+        result = operators.bmm(source[:, :, None]).squeeze(-1)
+        sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
+        ar_mask = (sequences == 11).long()
+        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
+        return sequences, ar_mask
+
+    def seq2str(self, seq):
+        return "".join("0123456789|>"[x.item()] for x in seq)
+
+
+class ProblemLevel2(Problem):
+    def __init__(self, len_source=5, len_result=8):
+        self.len_source = len_source
+        self.len_result = len_result
+
+    def generate_sequences(self, nb):
+        operators = F.one_hot(
+            torch.rand(nb, self.len_result, self.len_source).argmax(-1),
+            num_classes=self.len_source,
+        )
+        source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
+        marker1 = torch.full((nb, 1), 10)
+        result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
+        marker2 = torch.full((nb, 1), 11)
+        source2 = torch.randint(10, (nb, self.len_source))
+        marker3 = torch.full((nb, 1), 12)
+        result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
+
+        sequences = torch.cat(
+            (source1, marker1, result1, marker2, source2, marker3, result2), 1
+        )
+        ar_mask = (sequences == 12).long()
+        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
+        return sequences, ar_mask
+
+    def seq2str(self, seq):
+        return "".join("0123456789>|~"[x.item()] for x in seq)
+
+
+####################
+
+
+class ProblemAddition(Problem):
+    def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
+        self.nb_digits = nb_digits
+        self.zero_padded = zero_padded
+        self.inverted_result = inverted_result
+        self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")])
+        self.id2char = dict([(n, c) for c, n in self.char2id.items()])
+
+    def tensorize(self, strings):
+        len_max = max([len(x) for x in strings])
+        return torch.cat(
+            [
+                torch.tensor(
+                    [
+                        [self.char2id[c] for c in s + "$" * (len_max - len(s))]
+                        for s in strings
+                    ]
+                )
+            ],
+            0,
+        )
+
+    def generate_sequences(self, nb):
+        sequences = []
+        for k in range(nb):
+            a, b = torch.randint(10**self.nb_digits, (2,))
+            c = a + b
+            a, b, c = str(a.item()), str(b.item()), str(c.item())
+            if self.zero_padded:
+                a = "0" * (self.nb_digits - len(a)) + a
+                b = "0" * (self.nb_digits - len(b)) + b
+                c = "0" * (self.nb_digits + 1 - len(c)) + c
+            if self.inverted_result:
+                c = c[::-1]
+            sequences.append(f"{a}+{b}={c}$")
+
+        sequences = self.tensorize(sequences)
+        ar_mask = (sequences == self.char2id["="]).long()
+        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
+        return sequences, ar_mask
+
+    def seq2str(self, seq):
+        return "".join(self.id2char[x.item()] for x in seq)
+
+
+# class ProblemUnion(Problem):
+# problems = [ProblemByheart()]
+# nb_common_codes = 100
+
+# def generate_sequences(nb_samples):
+# problem_indexes = torch.randint(len(problems), (nb_samples,))
+# nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0)
+# print(f"{nb_samples_per_problem}")
+# all_seq = []
+# for nb, p in zip(nb_samples_per_problem, problems):
+# all_seq.append(p.generate_sequences(nb_samples_per_problem[nb]))
+# return all_seq
+
+# for strain, stest in zip(train_seq, test_seq):
+# s = torch.cat((strain, stest), 0)
+
index 17904d8..421aee4 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -72,158 +72,9 @@ class Task:
         pass
 
 
-######################################################################
-
-
-class Problem:
-    def generate_sequences(self, nb):
-        pass
-
-    def seq2str(self, seq):
-        return "[NOT IMPLEMENTED]"
-
-
-####################
-
-
-class ProblemLevel0(Problem):
-    def __init__(self, nb_sentences=100, len_prompt=5, len_result=5):
-        self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
-        self.seq[:, len_prompt] = 10
-
-    def generate_sequences(self, nb):
-        sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
-        ar_mask = (sequences == 10).long()
-        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
-        return sequences, ar_mask
-
-
-class ProblemLevel1(Problem):
-    def __init__(self, nb_operators=100, len_source=5, len_result=8):
-        self.len_source = len_source
-        self.len_result = len_result
-        self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
-        self.operators = F.one_hot(
-            torch.rand(nb_operators, len_result, len_source).argmax(-1),
-            num_classes=len_source,
-        )
-
-    def generate_sequences(self, nb):
-        nb_operators = torch.randint(self.operators.size(0), (nb,))
-        operators = self.operators[nb_operators]
-        nb_operators = (
-            nb_operators[:, None]
-            // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
-        ) % 10
-        marker1 = torch.full((nb, 1), 10)
-        # source = torch.randint(10, (nb, self.len_source))
-        source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
-        marker2 = torch.full((nb, 1), 11)
-        result = operators.bmm(source[:, :, None]).squeeze(-1)
-        sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
-        ar_mask = (sequences == 11).long()
-        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
-        return sequences, ar_mask
-
-    def seq2str(self, seq):
-        return "".join("0123456789|>"[x.item()] for x in seq)
-
-
-class ProblemLevel2(Problem):
-    def __init__(self, len_source=5, len_result=8):
-        self.len_source = len_source
-        self.len_result = len_result
-
-    def generate_sequences(self, nb):
-        operators = F.one_hot(
-            torch.rand(nb, self.len_result, self.len_source).argmax(-1),
-            num_classes=self.len_source,
-        )
-        source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
-        marker1 = torch.full((nb, 1), 10)
-        result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
-        marker2 = torch.full((nb, 1), 11)
-        source2 = torch.randint(10, (nb, self.len_source))
-        marker3 = torch.full((nb, 1), 12)
-        result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
-
-        sequences = torch.cat(
-            (source1, marker1, result1, marker2, source2, marker3, result2), 1
-        )
-        ar_mask = (sequences == 12).long()
-        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
-        return sequences, ar_mask
-
-    def seq2str(self, seq):
-        return "".join("0123456789>|~"[x.item()] for x in seq)
-
-
-####################
-
-
-class ProblemAddition(Problem):
-    def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
-        self.nb_digits = nb_digits
-        self.zero_padded = zero_padded
-        self.inverted_result = inverted_result
-        self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")])
-        self.id2char = dict([(n, c) for c, n in self.char2id.items()])
-
-    def tensorize(self, strings):
-        len_max = max([len(x) for x in strings])
-        return torch.cat(
-            [
-                torch.tensor(
-                    [
-                        [self.char2id[c] for c in s + "$" * (len_max - len(s))]
-                        for s in strings
-                    ]
-                )
-            ],
-            0,
-        )
-
-    def generate_sequences(self, nb):
-        sequences = []
-        for k in range(nb):
-            a, b = torch.randint(10**self.nb_digits, (2,))
-            c = a + b
-            a, b, c = str(a.item()), str(b.item()), str(c.item())
-            if self.zero_padded:
-                a = "0" * (self.nb_digits - len(a)) + a
-                b = "0" * (self.nb_digits - len(b)) + b
-                c = "0" * (self.nb_digits + 1 - len(c)) + c
-            if self.inverted_result:
-                c = c[::-1]
-            sequences.append(f"{a}+{b}={c}$")
-
-        sequences = self.tensorize(sequences)
-        ar_mask = (sequences == self.char2id["="]).long()
-        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
-        return sequences, ar_mask
-
-    def seq2str(self, seq):
-        return "".join(self.id2char[x.item()] for x in seq)
-
-
-# class ProblemUnion(Problem):
-# problems = [ProblemByheart()]
-# nb_common_codes = 100
-
-# def generate_sequences(nb_samples):
-# problem_indexes = torch.randint(len(problems), (nb_samples,))
-# nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0)
-# print(f"{nb_samples_per_problem}")
-# all_seq = []
-# for nb, p in zip(nb_samples_per_problem, problems):
-# all_seq.append(p.generate_sequences(nb_samples_per_problem[nb]))
-# return all_seq
-
-# for strain, stest in zip(train_seq, test_seq):
-# s = torch.cat((strain, stest), 0)
-
 ####################
 
+import problems
 
 class SandBox(Task):
     def __init__(
@@ -1283,7 +1134,7 @@ class RPL(Task):
         )
 
         if save_attention_image is not None:
-            ns=torch.randint(self.text_input.size(0),(1,)).item()
+            ns=torch.randint(self.test_input.size(0),(1,)).item()
             input = self.test_input[ns:ns+1].clone()
             last = (input != self.t_nul).max(0).values.nonzero().max() + 3
             input = input[:, :last].to(self.device)
@@ -1297,7 +1148,7 @@ class RPL(Task):
                 ram = model.retrieve_attention()
                 model.record_attention(False)
 
-            tokens_output = [self.id2token[i.item()] for i in input[ns]]
+            tokens_output = [self.id2token[i.item()] for i in input[0]]
             tokens_input = ["n/a"] + tokens_output[:-1]
             for n_head in range(ram[0].size(1)):
                 filename = os.path.join(