Update.
author François Fleuret <francois@fleuret.org>
Sun, 22 Oct 2023 15:07:10 +0000 (17:07 +0200)
committer François Fleuret <francois@fleuret.org>
Sun, 22 Oct 2023 15:07:10 +0000 (17:07 +0200)
problems.py

index 28e4f7b..51e90ed 100755 (executable)
@@ -293,13 +293,26 @@ class ProblemMixing(Problem):
         self.nb_time_steps = nb_time_steps
         self.hard = hard
 
-    def start(self, nb):
-        return (
-            torch.arange(self.height * self.width)
-            .reshape(1, 1, self.height, self.width)
-            .expand(nb, -1, -1, -1)
+    def start_random(self, nb):
+        y = torch.arange(self.height * self.width).reshape(1, -1).expand(nb, -1)
+
+        m = (torch.rand(y.size()).sort(dim=-1).indices < y.size(1) // 2).long()
+
+        y = (y * m + self.height * self.width * (1 - m)).reshape(
+            nb, self.height, self.width
         )
 
+        return y
+
+    def start_error(self, x):
+        x = x.flatten(1)
+        u = torch.arange(self.height * self.width).reshape(1, -1)
+        m = ((x - u).abs() == 0).long()
+        d = (x - (m * u + (1-m) * self.height * self.width)).abs().sum(-1) + (
+            m.sum(dim=-1) != self.height * self.width // 2
+        ).long()
+        return d
+
     def moves(self, x):
         y = (
             x[:, None, :, :]
@@ -323,8 +336,7 @@ class ProblemMixing(Problem):
         return y
 
     def generate_sequences(self, nb):
-        y = self.start(nb)
-        x = y[torch.arange(nb), torch.randint(y.size(1), (nb,))]
+        x = self.start_random(nb)
 
         seq = [x.flatten(1)]
 
@@ -349,8 +361,7 @@ class ProblemMixing(Problem):
 
         x = a[0]
 
-        y = self.start(result.size(0)).to(x.device)
-        d = (x[:, None] - y).abs().sum((-1, -2)).min(dim=-1).values
+        d = self.start_error(x)
 
         for t in range(self.nb_time_steps - 1):
             x0, x = a[t], a[t + 1]
@@ -375,7 +386,7 @@ class ProblemMixing(Problem):
 ####################
 
 if __name__ == "__main__":
-    p = ProblemMixing(hard=True)
+    p = ProblemMixing(width=4, hard=True)
     s, m = p.generate_sequences(10000)
     for x in s[:5]:
         print(p.seq2str(x))