Update.
[picoclvr.git] / problems.py
index 9321194..ac16df4 100755 (executable)
@@ -289,36 +289,38 @@ class ProblemAddition(Problem):
 
 
 class ProblemMixing(Problem):
-    def __init__(self, height=4, width=4, nb_time_steps=9, hard=False):
+    def __init__(
+        self, height=4, width=4, nb_time_steps=9, hard=False, random_start=True
+    ):
         self.height = height
         self.width = width
         self.nb_time_steps = nb_time_steps
         self.hard = hard
+        self.random_start = random_start
 
     def start_random(self, nb):
         y = torch.arange(self.height * self.width).reshape(1, -1).expand(nb, -1)
 
-        # m = (torch.rand(y.size()).sort(dim=-1).indices < y.size(1) // 2).long()
+        if self.random_start:
+            i = (
+                torch.arange(self.height)
+                .reshape(1, -1, 1)
+                .expand(nb, self.height, self.width)
+            )
+            j = (
+                torch.arange(self.width)
+                .reshape(1, 1, -1)
+                .expand(nb, self.height, self.width)
+            )
 
-        i = (
-            torch.arange(self.height)
-            .reshape(1, -1, 1)
-            .expand(nb, self.height, self.width)
-        )
-        j = (
-            torch.arange(self.width)
-            .reshape(1, 1, -1)
-            .expand(nb, self.height, self.width)
-        )
+            ri = torch.randint(self.height, (nb,)).reshape(nb, 1, 1)
+            rj = torch.randint(self.width, (nb,)).reshape(nb, 1, 1)
 
-        ri = torch.randint(self.height, (nb,)).reshape(nb, 1, 1)
-        rj = torch.randint(self.width, (nb,)).reshape(nb, 1, 1)
+            m = 1 - torch.logical_or(i == ri, j == rj).long().flatten(1)
 
-        m = 1 - torch.logical_or(i == ri, j == rj).long().flatten(1)
+            y = y * m + self.height * self.width * (1 - m)
 
-        y = (y * m + self.height * self.width * (1 - m)).reshape(
-            nb, self.height, self.width
-        )
+        y = y.reshape(nb, self.height, self.width)
 
         return y