# This code implement a simple system to manipulate formal
# specifications of tokens on a grid.
-import math, re
+import math, re, random
import torch
def constraint_to_fun(self, constraint):
a, b, c = None, None, None
+ col, row = self.col, self.row
def match(pattern):
nonlocal a, b, c
return False
if match("([1-9]) is_in_top_half"):
- return self.row[:, a] < self.grid_height // 2
+ return row[:, a] < self.grid_height // 2
elif match("([1-9]) is_in_bottom_half"):
- return self.row[:, a] >= self.grid_height // 2
+ return row[:, a] >= self.grid_height // 2
elif match("([1-9]) is_on_left_side"):
- return self.col[:, a] < self.grid_width // 2
+ return col[:, a] < self.grid_width // 2
elif match("([1-9]) is_on_right_side"):
- return self.col[:, a] >= self.grid_width // 2
+ return col[:, a] >= self.grid_width // 2
elif match("([1-9]) next_to ([1-9])"):
- return (self.row[:, a] - self.row[:, b]).abs() + (
- self.col[:, a] - self.col[:, b]
- ).abs() <= 1
+ return (row[:, a] - row[:, b]).abs() + (col[:, a] - col[:, b]).abs() == 1
elif match("([1-9]) is_below ([1-9])"):
- return self.row[:, a] > self.row[:, b]
+ return row[:, a] > row[:, b]
elif match("([1-9]) is_above ([1-9])"):
- return self.row[:, a] < self.row[:, b]
+ return row[:, a] < row[:, b]
elif match("([1-9]) is_left_of ([1-9])"):
- return self.col[:, a] < self.col[:, b]
+ return col[:, a] < col[:, b]
elif match("([1-9]) is_right_of ([1-9])"):
- return self.col[:, a] > self.col[:, b]
+ return col[:, a] > col[:, b]
elif match("([1-9]) ([1-9]) is_parallel_to_diagonal"):
- return (self.col[:, a] - self.col[:, b]).abs() == (
- self.row[:, a] - self.row[:, b]
- ).abs()
+ return (col[:, a] - col[:, b]).abs() == (row[:, a] - row[:, b]).abs()
elif match("([1-9]) ([1-9]) is_vertical"):
- return self.col[:, a] == self.col[:, b]
+ return col[:, a] == col[:, b]
elif match("([1-9]) ([1-9]) is_horizontal"):
- return self.row[:, a] == self.row[:, b]
+ return row[:, a] == row[:, b]
elif match("([1-9]) ([1-9]) ([1-9]) are_aligned"):
- return (self.col[:, a] - self.col[:, b]) * (
- self.row[:, a] - self.row[:, c]
- ) - (self.row[:, a] - self.row[:, b]) * (
- self.col[:, a] - self.col[:, c]
- ) == 0
+ return (col[:, a] - col[:, b]) * (row[:, a] - row[:, c]) - (
+ row[:, a] - row[:, b]
+ ) * (col[:, a] - col[:, c]) == 0
elif match("([1-9]) middle_of ([1-9]) ([1-9])"):
- return (
- grid_set
- & (self.col[:, a] + self.col[:, c] == 2 * self.col[:, b])
- & (self.row[:, a] + self.row[:, c] == 2 * self.row[:, b])
+ return (col[:, b] + col[:, a] == 2 * col[:, b]) & (
+ row[:, b] + row[:, a] == 2 * row[:, b]
)
elif match("([1-9]) is_equidistant_from ([1-9]) and ([1-9])"):
- return (self.col[:, a] - self.col[:, b]) ** 2 + (
- self.row[:, a] - self.row[:, b]
- ) ** 2 == (self.col[:, a] - self.col[:, c]) ** 2 + (
- self.row[:, a] - self.row[:, c]
- ) ** 2
-
- elif match("([1-9]) is_further_away_from ([1-9]) than ([1-9])"):
- return (self.col[:, a] - self.col[:, b]) ** 2 + (
- self.row[:, a] - self.row[:, b]
- ) ** 2 > (self.col[:, a] - self.col[:, c]) ** 2 + (
- self.row[:, a] - self.row[:, c]
- ) ** 2
+ return (col[:, a] - col[:, b]) ** 2 + (row[:, a] - row[:, b]) ** 2 == (
+ col[:, a] - col[:, c]
+ ) ** 2 + (row[:, a] - row[:, c]) ** 2
+
+ elif match("([1-9]) is_further_from ([1-9]) than_from ([1-9])"):
+ return (col[:, a] - col[:, b]) ** 2 + (row[:, a] - row[:, b]) ** 2 > (
+ col[:, a] - col[:, c]
+ ) ** 2 + (row[:, a] - row[:, c]) ** 2
+
+ elif match("([1-9]) is_closer_to ([1-9]) than_to ([1-9])"):
+ return (col[:, a] - col[:, b]) ** 2 + (row[:, a] - row[:, b]) ** 2 < (
+ col[:, a] - col[:, c]
+ ) ** 2 + (row[:, a] - row[:, c]) ** 2
elif match("([1-9]) ([1-9]) ([1-9]) form_a_right_angle"):
- return (self.col[:, a] - self.col[:, b]) * (
- self.col[:, c] - self.col[:, b]
- ) + (self.row[:, a] - self.row[:, b]) * (
- self.row[:, c] - self.row[:, b]
- ) == 0
+ return (col[:, a] - col[:, b]) * (col[:, c] - col[:, b]) + (
+ row[:, a] - row[:, b]
+ ) * (row[:, c] - row[:, b]) == 0
else:
raise ValueError(f"Unknown type of constraint {constraint}")
v += " ".join(["-" if n == 0 else str(n.item()) for n in r]) + "\n"
yield v
+ def random_property(self):
+ a, b, c = random.sample(list(range(1, self.nb_symbols + 1)), 3)
+
+ sb, sc = min(b, c), max(b, c)
+
+ ta, tb, tc = sorted((a, b, c))
+
+ l = (
+ [
+ f"{a} is_in_top_half",
+ f"{a} is_in_bottom_half",
+ f"{a} is_on_left_side",
+ f"{a} is_on_right_side",
+ ]
+ + [
+ f"{a} is_below {b}",
+ f"{a} is_above {b}",
+ f"{a} is_left_of {b}",
+ f"{a} is_right_of {b}",
+ f"{sb} next_to {sc}",
+ ]
+ + [
+ f"{sb} {sc} is_parallel_to_diagonal",
+ f"{sb} {sc} is_vertical",
+ f"{sb} {sc} is_horizontal",
+ ]
+ * 2
+ + [
+ f"{ta} {tb} {tc} are_aligned",
+ f"{a} middle_of {sb} {sc}",
+ f"{ta} {tb} {tc} form_a_right_angle",
+ ]
+ * 3
+ + [
+ f"{a} is_equidistant_from {sb} and {sc}",
+ f"{a} is_further_from {b} than_from {c}",
+ f"{a} is_closer_to {b} than_to {c}",
+ ]
+ )
+
+ return random.choice(l)
+
######################################################################
if __name__ == "__main__":
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- grid = FormalGrid(grid_height=8, grid_width=8, nb_symbols=4, device=device)
+ if torch.cuda.is_available():
+ device = torch.device("cuda")
+ elif torch.backends.mps.is_available():
+ device = torch.device("mps")
+ else:
+ device = torch.device("cpu")
+
+ # grid = FormalGrid(grid_height=7, grid_width=7, nb_symbols=4, device=device)
+ # grid_set = grid.new_grid_set(["4 is_equidistant_from 2 and 3", "2 4 is_parallel_to_diagonal"])
+ # print(next(iter(grid.views(grid_set))))
+ # exit(0)
+
+ def proof_depth(steps, c):
+ a = steps.get(c)
+ if a is None:
+ return 0
+ else:
+ c1, c2 = a
+ return max(proof_depth(steps, c1), proof_depth(steps, c2))
+
+ def generate_proof(grid):
+ while True:
+ constraints = [grid.random_property() for _ in range(10)]
+ grid_set = grid.new_grid_set(constraints)
+ if grid_set.any():
+ break
+
+ mg = grid.master_grid_set
+
+ print(constraints)
+
+ initial = constraints.copy()
+
+ steps = {}
+
+ for _ in range(1000):
+ c1, c2 = random.sample(constraints, 2)
+ f1, f2 = grid.constraint_to_fun(c1), grid.constraint_to_fun(c2)
+ for _ in range(100):
+ c = grid.random_property()
+ if c not in constraints:
+ f = grid.constraint_to_fun(c)
+ if (
+ (mg & f1 & ~f).any()
+ and (mg & f2 & ~f).any()
+ and (mg & f1 & f2 & f).any()
+ and not (mg & f1 & f2 & ~f).any()
+ ):
+ constraints.append(c)
+ print(c1, "and", c2, "=>", c)
+ steps[c] = (c1, c2)
+ # print(next(iter(grid.views(grid.new_grid_set([c1, c2])))))
+ # print("we have", constraints)
+ # proof.append(c1 + " and " + c2 + " hence " + c)
+ break
+
+ if steps.keys() and max([proof_depth(steps, c) for c in steps.keys()]) >= 3:
+
+ break
+
+ return initial, steps
+
+ grid = FormalGrid(grid_height=7, grid_width=7, nb_symbols=4, device=device)
+
+ initial, steps = generate_proof(grid)
+
+ print(" ; ".join(initial))
+
+ def proof(c, indent=""):
+ a = steps.get(c)
+ if a is None:
+ print(f"{indent}{c} is given")
+ else:
+ print(f"{indent}{c} since")
+ c1, c2 = a
+ proof(c1, indent + " ")
+ proof(c2, indent + " ")
+
+ print(" ; ".join(initial))
+
+ for c in steps.keys():
+ proof(c)
+ print()
+
+ exit(0)
# grid_set = grid.new_grid_set(
# [
"2 3 is_parallel_to_diagonal",
"4 1 is_vertical",
"3 4 is_horizontal",
+ "3 is_left_of 2",
+ "1 is_below 4",
+ "2 is_right_of 4",
],
)