Update.
author François Fleuret <francois@fleuret.org>
Sun, 22 Oct 2023 13:35:46 +0000 (15:35 +0200)
committer François Fleuret <francois@fleuret.org>
Sun, 22 Oct 2023 13:35:46 +0000 (15:35 +0200)
main.py
problems.py

diff --git a/main.py b/main.py
index 6e87cda..496a603 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -33,7 +33,7 @@ parser.add_argument(
     "--task",
     type=str,
     default="twotargets",
-    help="byheart, learnop, guessop, degradation, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp",
+    help="byheart, learnop, guessop, mixing, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp",
 )
 
 parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
@@ -162,7 +162,7 @@ parser.add_argument("--expr_input_file", type=str, default=None)
 ##############################
 # Misc
 
-parser.add_argument("--degradation_hard", action="store_true", default=False)
+parser.add_argument("--mixing_hard", action="store_true", default=False)
 
 ######################################################################
 
@@ -254,7 +254,7 @@ default_task_args = {
         "nb_train_samples": 50000,
         "nb_test_samples": 10000,
     },
-    "degradation": {
+    "mixing": {
         "model": "37M",
         "batch_size": 25,
         "nb_train_samples": 250000,
@@ -414,9 +414,9 @@ elif args.task == "twotargets":
         device=device,
     )
 
-elif args.task == "degradation":
+elif args.task == "mixing":
     task = tasks.SandBox(
-        problem=problems.ProblemDegradation(hard=args.degradation_hard),
+        problem=problems.ProblemMixing(hard=args.mixing_hard),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
         batch_size=args.batch_size,
diff --git a/problems.py b/problems.py
index 22b6517..28e4f7b 100755 (executable)
--- a/problems.py
+++ b/problems.py
@@ -24,8 +24,6 @@ class Problem:
 
 
 ####################
-
-
 class ProblemDegradation(Problem):
     def __init__(self, nb_state_tokens=5, nb_time_steps=12, value_max=25, hard=False):
         assert value_max // nb_state_tokens >= 2
@@ -285,9 +283,100 @@ class ProblemAddition(Problem):
         return "".join(self.id2char[x.item()] for x in seq)
 
 
+####################
+
+
+class ProblemMixing(Problem):
+    def __init__(self, height=3, width=3, nb_time_steps=12, hard=False):
+        self.height = height
+        self.width = width
+        self.nb_time_steps = nb_time_steps
+        self.hard = hard
+
+    def start(self, nb):
+        return (
+            torch.arange(self.height * self.width)
+            .reshape(1, 1, self.height, self.width)
+            .expand(nb, -1, -1, -1)
+        )
+
+    def moves(self, x):
+        y = (
+            x[:, None, :, :]
+            .expand(-1, self.height * 2 + self.width * 2, -1, -1)
+            .clone()
+        )
+        k = 0
+
+        for i in range(self.height):
+            y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=-1)
+            k += 1
+            y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=1)
+            k += 1
+
+        for j in range(self.width):
+            y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=-1)
+            k += 1
+            y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=1)
+            k += 1
+
+        return y
+
+    def generate_sequences(self, nb):
+        y = self.start(nb)
+        x = y[torch.arange(nb), torch.randint(y.size(1), (nb,))]
+
+        seq = [x.flatten(1)]
+
+        for t in range(self.nb_time_steps - 1):
+            y = self.moves(x)
+            x = y[torch.arange(nb), torch.randint(y.size(1), (nb,))]
+            seq.append(x.flatten(1))
+
+        if self.hard:
+            seq.reverse()
+
+        seq = torch.cat(seq, dim=1)
+        return seq, seq.new_full(seq.size(), 1, dtype=torch.int64)
+
+    def compute_nb_correct(self, input, ar_mask, result):
+        a = [
+            x.reshape(result.size(0), self.height, self.width)
+            for x in result.split(self.height * self.width, dim=1)
+        ]
+        if self.hard:
+            a.reverse()
+
+        x = a[0]
+
+        y = self.start(result.size(0)).to(x.device)
+        d = (x[:, None] - y).abs().sum((-1, -2)).min(dim=-1).values
+
+        for t in range(self.nb_time_steps - 1):
+            x0, x = a[t], a[t + 1]
+            y = self.moves(x0)
+            d = d + (x[:, None] - y).abs().sum((-1, -2)).min(dim=-1).values
+
+        nb_total, nb_correct = result.size(0), (d == 0).long().sum().item()
+
+        return nb_total, nb_correct
+
+    def seq2str(self, seq):
+        return " | ".join(
+            [
+                " ".join(
+                    ["-".join([f"{x:02d}" for x in s]) for s in r.split(self.width)]
+                )
+                for r in seq.split(self.height * self.width)
+            ]
+        )
+
+
+####################
+
 if __name__ == "__main__":
-    p = ProblemDegradation(hard=False)
+    p = ProblemMixing(hard=True)
     s, m = p.generate_sequences(10000)
-    for x in s[:100]:
+    for x in s[:5]:
         print(p.seq2str(x))
     print(p.compute_nb_correct(None, None, s))