Update. master
authorFrançois Fleuret <fleuret@meta.com>
Tue, 17 Jun 2025 15:14:20 +0000 (17:14 +0200)
committerFrançois Fleuret <fleuret@meta.com>
Tue, 17 Jun 2025 15:14:20 +0000 (17:14 +0200)
grid.py
picocrafter.py
tinyae.py
tinymnist.py

diff --git a/grid.py b/grid.py
index d06802d..ac0ebe0 100755 (executable)
--- a/grid.py
+++ b/grid.py
@@ -9,7 +9,7 @@
 # This code implement a simple system to manipulate formal
 # specifications of tokens on a grid.
 
-import math, re
+import math, re, random
 
 import torch
 
@@ -61,6 +61,7 @@ class FormalGrid:
 
     def constraint_to_fun(self, constraint):
         a, b, c = None, None, None
+        col, row = self.col, self.row
 
         def match(pattern):
             nonlocal a, b, c
@@ -73,79 +74,70 @@ class FormalGrid:
                 return False
 
         if match("([1-9]) is_in_top_half"):
-            return self.row[:, a] < self.grid_height // 2
+            return row[:, a] < self.grid_height // 2
 
         elif match("([1-9]) is_in_bottom_half"):
-            return self.row[:, a] >= self.grid_height // 2
+            return row[:, a] >= self.grid_height // 2
 
         elif match("([1-9]) is_on_left_side"):
-            return self.col[:, a] < self.grid_width // 2
+            return col[:, a] < self.grid_width // 2
 
         elif match("([1-9]) is_on_right_side"):
-            return self.col[:, a] >= self.grid_width // 2
+            return col[:, a] >= self.grid_width // 2
 
         elif match("([1-9]) next_to ([1-9])"):
-            return (self.row[:, a] - self.row[:, b]).abs() + (
-                self.col[:, a] - self.col[:, b]
-            ).abs() <= 1
+            return (row[:, a] - row[:, b]).abs() + (col[:, a] - col[:, b]).abs() == 1
 
         elif match("([1-9]) is_below ([1-9])"):
-            return self.row[:, a] > self.row[:, b]
+            return row[:, a] > row[:, b]
 
         elif match("([1-9]) is_above ([1-9])"):
-            return self.row[:, a] < self.row[:, b]
+            return row[:, a] < row[:, b]
 
         elif match("([1-9]) is_left_of ([1-9])"):
-            return self.col[:, a] < self.col[:, b]
+            return col[:, a] < col[:, b]
 
         elif match("([1-9]) is_right_of ([1-9])"):
-            return self.col[:, a] > self.col[:, b]
+            return col[:, a] > col[:, b]
 
         elif match("([1-9]) ([1-9]) is_parallel_to_diagonal"):
-            return (self.col[:, a] - self.col[:, b]).abs() == (
-                self.row[:, a] - self.row[:, b]
-            ).abs()
+            return (col[:, a] - col[:, b]).abs() == (row[:, a] - row[:, b]).abs()
 
         elif match("([1-9]) ([1-9]) is_vertical"):
-            return self.col[:, a] == self.col[:, b]
+            return col[:, a] == col[:, b]
 
         elif match("([1-9]) ([1-9]) is_horizontal"):
-            return self.row[:, a] == self.row[:, b]
+            return row[:, a] == row[:, b]
 
         elif match("([1-9]) ([1-9]) ([1-9]) are_aligned"):
-            return (self.col[:, a] - self.col[:, b]) * (
-                self.row[:, a] - self.row[:, c]
-            ) - (self.row[:, a] - self.row[:, b]) * (
-                self.col[:, a] - self.col[:, c]
-            ) == 0
+            return (col[:, a] - col[:, b]) * (row[:, a] - row[:, c]) - (
+                row[:, a] - row[:, b]
+            ) * (col[:, a] - col[:, c]) == 0
 
         elif match("([1-9]) middle_of ([1-9]) ([1-9])"):
-            return (
-                grid_set
-                & (self.col[:, a] + self.col[:, c] == 2 * self.col[:, b])
-                & (self.row[:, a] + self.row[:, c] == 2 * self.row[:, b])
+            return (col[:, b] + col[:, a] == 2 * col[:, b]) & (
+                row[:, b] + row[:, a] == 2 * row[:, b]
             )
 
         elif match("([1-9]) is_equidistant_from ([1-9]) and ([1-9])"):
-            return (self.col[:, a] - self.col[:, b]) ** 2 + (
-                self.row[:, a] - self.row[:, b]
-            ) ** 2 == (self.col[:, a] - self.col[:, c]) ** 2 + (
-                self.row[:, a] - self.row[:, c]
-            ) ** 2
-
-        elif match("([1-9]) is_further_away_from ([1-9]) than ([1-9])"):
-            return (self.col[:, a] - self.col[:, b]) ** 2 + (
-                self.row[:, a] - self.row[:, b]
-            ) ** 2 > (self.col[:, a] - self.col[:, c]) ** 2 + (
-                self.row[:, a] - self.row[:, c]
-            ) ** 2
+            return (col[:, a] - col[:, b]) ** 2 + (row[:, a] - row[:, b]) ** 2 == (
+                col[:, a] - col[:, c]
+            ) ** 2 + (row[:, a] - row[:, c]) ** 2
+
+        elif match("([1-9]) is_further_from ([1-9]) than_from ([1-9])"):
+            return (col[:, a] - col[:, b]) ** 2 + (row[:, a] - row[:, b]) ** 2 > (
+                col[:, a] - col[:, c]
+            ) ** 2 + (row[:, a] - row[:, c]) ** 2
+
+        elif match("([1-9]) is_closer_to ([1-9]) than_to ([1-9])"):
+            return (col[:, a] - col[:, b]) ** 2 + (row[:, a] - row[:, b]) ** 2 < (
+                col[:, a] - col[:, c]
+            ) ** 2 + (row[:, a] - row[:, c]) ** 2
 
         elif match("([1-9]) ([1-9]) ([1-9]) form_a_right_angle"):
-            return (self.col[:, a] - self.col[:, b]) * (
-                self.col[:, c] - self.col[:, b]
-            ) + (self.row[:, a] - self.row[:, b]) * (
-                self.row[:, c] - self.row[:, b]
-            ) == 0
+            return (col[:, a] - col[:, b]) * (col[:, c] - col[:, b]) + (
+                row[:, a] - row[:, b]
+            ) * (row[:, c] - row[:, b]) == 0
 
         else:
             raise ValueError(f"Unknown type of constraint {constraint}")
@@ -184,13 +176,138 @@ class FormalGrid:
                 v += " ".join(["-" if n == 0 else str(n.item()) for n in r]) + "\n"
             yield v
 
+    def random_property(self):
+        a, b, c = random.sample(list(range(1, self.nb_symbols + 1)), 3)
+
+        sb, sc = min(b, c), max(b, c)
+
+        ta, tb, tc = sorted((a, b, c))
+
+        l = (
+            [
+                f"{a} is_in_top_half",
+                f"{a} is_in_bottom_half",
+                f"{a} is_on_left_side",
+                f"{a} is_on_right_side",
+            ]
+            + [
+                f"{a} is_below {b}",
+                f"{a} is_above {b}",
+                f"{a} is_left_of {b}",
+                f"{a} is_right_of {b}",
+                f"{sb} next_to {sc}",
+            ]
+            + [
+                f"{sb} {sc} is_parallel_to_diagonal",
+                f"{sb} {sc} is_vertical",
+                f"{sb} {sc} is_horizontal",
+            ]
+            * 2
+            + [
+                f"{ta} {tb} {tc} are_aligned",
+                f"{a} middle_of {sb} {sc}",
+                f"{ta} {tb} {tc} form_a_right_angle",
+            ]
+            * 3
+            + [
+                f"{a} is_equidistant_from {sb} and {sc}",
+                f"{a} is_further_from {b} than_from {c}",
+                f"{a} is_closer_to {b} than_to {c}",
+            ]
+        )
+
+        return random.choice(l)
+
 
 ######################################################################
 
 if __name__ == "__main__":
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    grid = FormalGrid(grid_height=8, grid_width=8, nb_symbols=4, device=device)
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    elif torch.backends.mps.is_available():
+        device = torch.device("mps")
+    else:
+        device = torch.device("cpu")
+
+    # grid = FormalGrid(grid_height=7, grid_width=7, nb_symbols=4, device=device)
+    # grid_set = grid.new_grid_set(["4 is_equidistant_from 2 and 3", "2 4 is_parallel_to_diagonal"])
+    # print(next(iter(grid.views(grid_set))))
+    # exit(0)
+
+    def proof_depth(steps, c):
+        a = steps.get(c)
+        if a is None:
+            return 0
+        else:
+            c1, c2 = a
+            return max(proof_depth(steps, c1), proof_depth(steps, c2))
+
+    def generate_proof(grid):
+        while True:
+            constraints = [grid.random_property() for _ in range(10)]
+            grid_set = grid.new_grid_set(constraints)
+            if grid_set.any():
+                break
+
+        mg = grid.master_grid_set
+
+        print(constraints)
+
+        initial = constraints.copy()
+
+        steps = {}
+
+        for _ in range(1000):
+            c1, c2 = random.sample(constraints, 2)
+            f1, f2 = grid.constraint_to_fun(c1), grid.constraint_to_fun(c2)
+            for _ in range(100):
+                c = grid.random_property()
+                if c not in constraints:
+                    f = grid.constraint_to_fun(c)
+                    if (
+                        (mg & f1 & ~f).any()
+                        and (mg & f2 & ~f).any()
+                        and (mg & f1 & f2 & f).any()
+                        and not (mg & f1 & f2 & ~f).any()
+                    ):
+                        constraints.append(c)
+                        print(c1, "and", c2, "=>", c)
+                        steps[c] = (c1, c2)
+                        # print(next(iter(grid.views(grid.new_grid_set([c1, c2])))))
+                        # print("we have", constraints)
+                        # proof.append(c1 + " and " + c2 + " hence " + c)
+                        break
+
+            if steps.keys() and max([proof_depth(steps, c) for c in steps.keys()]) >= 3:
+
+                break
+
+        return initial, steps
+
+    grid = FormalGrid(grid_height=7, grid_width=7, nb_symbols=4, device=device)
+
+    initial, steps = generate_proof(grid)
+
+    print(" ; ".join(initial))
+
+    def proof(c, indent=""):
+        a = steps.get(c)
+        if a is None:
+            print(f"{indent}{c} is given")
+        else:
+            print(f"{indent}{c} since")
+            c1, c2 = a
+            proof(c1, indent + "  ")
+            proof(c2, indent + "  ")
+
+    print(" ; ".join(initial))
+
+    for c in steps.keys():
+        proof(c)
+        print()
+
+    exit(0)
 
     # grid_set = grid.new_grid_set(
     # [
@@ -208,6 +325,9 @@ if __name__ == "__main__":
             "2 3 is_parallel_to_diagonal",
             "4 1 is_vertical",
             "3 4 is_horizontal",
+            "3 is_left_of 2",
+            "1 is_below 4",
+            "2 is_right_of 4",
         ],
     )
 
index 23d93b2..001bb81 100755 (executable)
@@ -227,9 +227,9 @@ class PicroCrafterEnvironment:
         r = u.sort(dim=-1, descending=True).indices[:, : len(z)]
 
         q *= self.tile2id["#"]
-        q[
-            torch.arange(q.size(0), device=q.device)[:, None].expand_as(r), r
-        ] = torch.tensor([self.tile2id[c] for c in z], device=q.device)[None, :]
+        q[torch.arange(q.size(0), device=q.device)[:, None].expand_as(r), r] = (
+            torch.tensor([self.tile2id[c] for c in z], device=q.device)[None, :]
+        )
 
         if world_margin > 0:
             r = m.new_full(
index 0baa5a2..806559e 100755 (executable)
--- a/tinyae.py
+++ b/tinyae.py
@@ -92,8 +92,8 @@ class AutoEncoder(nn.Module):
         return self.decoder(z.view(z.size(0), -1, 1, 1))
 
     def forward(self, x):
-        x = self.encoder(x)
-        x = self.decoder(x)
+        x = self.encode(x)
+        x = self.decode(x)
         return x
 
 
index f662be6..19b9387 100755 (executable)
@@ -1,5 +1,8 @@
 #!/usr/bin/env python
 
+# @XREMOTE_HOST: elk.fleuret.org
+# @XREMOTE_PRE: source ~/venv/pytorch/bin/activate
+
 # Any copyright is dedicated to the Public Domain.
 # https://creativecommons.org/publicdomain/zero/1.0/
 
@@ -14,7 +17,12 @@ lr, nb_epochs, batch_size = 1e-1, 10, 100
 
 data_dir = os.environ.get("PYTORCH_DATA_DIR") or "./data/"
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+elif torch.backends.mps.is_available():
+    device = torch.device("mps")
+else:
+    device = torch.device("cpu")
 
 ######################################################################
 
@@ -52,7 +60,7 @@ model = SomeLeNet()
 
 nb_parameters = sum(p.numel() for p in model.parameters())
 
-print(f"nb_parameters {nb_parameters}")
+print(f"device {device} nb_parameters {nb_parameters}")
 
 optimizer = torch.optim.SGD(model.parameters(), lr=lr)
 criterion = nn.CrossEntropyLoss()
@@ -83,17 +91,23 @@ for k in range(nb_epochs):
         loss.backward()
         optimizer.step()
 
+    acc_test_loss = 0.0
     nb_test_errors = 0
     for input, targets in zip(
         test_input.split(batch_size), test_targets.split(batch_size)
     ):
-        wta = model(input).argmax(1)
+        output = model(input)
+        loss = criterion(output, targets)
+        acc_test_loss += loss.item() * input.size(0)
+
+        wta = output.argmax(1)
         nb_test_errors += (wta != targets).long().sum()
+
     test_error = nb_test_errors / test_input.size(0)
     duration = time.perf_counter() - start_time
 
     print(
-        f"loss {k} {duration:.02f}s {acc_train_loss/train_input.size(0):.02f} {test_error*100:.02f}%"
+        f"loss {k} {duration:.02f}s acc_train_loss {acc_train_loss/train_input.size(0):.02f} test_loss {acc_test_loss/test_input.size(0):.02f} test_error {test_error*100:.02f}%"
     )
 
 ######################################################################