Update.
[picoclvr.git] / problems.py
index 22b6517..28e4f7b 100755 (executable)
@@ -24,8 +24,6 @@ class Problem:
 
 
 ####################
-
-
 class ProblemDegradation(Problem):
     def __init__(self, nb_state_tokens=5, nb_time_steps=12, value_max=25, hard=False):
         assert value_max // nb_state_tokens >= 2
@@ -285,9 +283,100 @@ class ProblemAddition(Problem):
         return "".join(self.id2char[x.item()] for x in seq)
 
 
+####################
+
+
+class ProblemMixing(Problem):
+    def __init__(self, height=3, width=3, nb_time_steps=12, hard=False):
+        self.height = height
+        self.width = width
+        self.nb_time_steps = nb_time_steps
+        self.hard = hard
+
+    def start(self, nb):
+        return (
+            torch.arange(self.height * self.width)
+            .reshape(1, 1, self.height, self.width)
+            .expand(nb, -1, -1, -1)
+        )
+
+    def moves(self, x):
+        y = (
+            x[:, None, :, :]
+            .expand(-1, self.height * 2 + self.width * 2, -1, -1)
+            .clone()
+        )
+        k = 0
+
+        for i in range(self.height):
+            y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=-1)
+            k += 1
+            y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=1)
+            k += 1
+
+        for j in range(self.width):
+            y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=-1)
+            k += 1
+            y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=1)
+            k += 1
+
+        return y
+
+    def generate_sequences(self, nb):
+        y = self.start(nb)
+        x = y[torch.arange(nb), torch.randint(y.size(1), (nb,))]
+
+        seq = [x.flatten(1)]
+
+        for t in range(self.nb_time_steps - 1):
+            y = self.moves(x)
+            x = y[torch.arange(nb), torch.randint(y.size(1), (nb,))]
+            seq.append(x.flatten(1))
+
+        if self.hard:
+            seq.reverse()
+
+        seq = torch.cat(seq, dim=1)
+        return seq, seq.new_full(seq.size(), 1, dtype=torch.int64)
+
+    def compute_nb_correct(self, input, ar_mask, result):
+        a = [
+            x.reshape(result.size(0), self.height, self.width)
+            for x in result.split(self.height * self.width, dim=1)
+        ]
+        if self.hard:
+            a.reverse()
+
+        x = a[0]
+
+        y = self.start(result.size(0)).to(x.device)
+        d = (x[:, None] - y).abs().sum((-1, -2)).min(dim=-1).values
+
+        for t in range(self.nb_time_steps - 1):
+            x0, x = a[t], a[t + 1]
+            y = self.moves(x0)
+            d = d + (x[:, None] - y).abs().sum((-1, -2)).min(dim=-1).values
+
+        nb_total, nb_correct = result.size(0), (d == 0).long().sum().item()
+
+        return nb_total, nb_correct
+
+    def seq2str(self, seq):
+        return " | ".join(
+            [
+                " ".join(
+                    ["-".join([f"{x:02d}" for x in s]) for s in r.split(self.width)]
+                )
+                for r in seq.split(self.height * self.width)
+            ]
+        )
+
+
+####################
+
 if __name__ == "__main__":
-    p = ProblemDegradation(hard=False)
+    p = ProblemMixing(hard=True)
     s, m = p.generate_sequences(10000)
-    for x in s[:100]:
+    for x in s[:5]:
         print(p.seq2str(x))
     print(p.compute_nb_correct(None, None, s))