Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 24 Aug 2023 20:34:32 +0000 (22:34 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 24 Aug 2023 20:34:32 +0000 (22:34 +0200)
grid.py [new file with mode: 0755]
main.py

diff --git a/grid.py b/grid.py
new file mode 100755 (executable)
index 0000000..08ddc23
--- /dev/null
+++ b/grid.py
@@ -0,0 +1,167 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math
+import torch, torchvision
+import torch.nn.functional as F
+
+name_shapes = ["A", "B", "C", "D", "E", "F"]
+
+name_colors = ["red", "yellow", "blue", "green", "white", "purple"]
+
+######################################################################
+
+
+class GridFactory:
+    def __init__(
+        self,
+        height=4,
+        width=4,
+        max_nb_items=4,
+        max_nb_transformations=4,
+        nb_questions=4,
+    ):
+        self.height = height
+        self.width = width
+        self.max_nb_items = max_nb_items
+        self.nb_questions = nb_questions
+
+    def generate_scene(self):
+        nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2
+        col = torch.full((self.height * self.width,), -1)
+        shp = torch.full((self.height * self.width,), -1)
+        a = torch.randperm(len(name_colors) * len(name_shapes))[:nb_items]
+        col[:nb_items] = a % len(name_colors)
+        shp[:nb_items] = a // len(name_colors)
+        i = torch.randperm(self.height * self.width)
+        col = col[i]
+        shp = shp[i]
+        return col.reshape(self.height, self.width), shp.reshape(
+            self.height, self.width
+        )
+
+    def random_transformations(self):
+        nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item()
+
+    def print_scene(self, scene):
+        col, shp = scene
+
+        # for i in range(self.height):
+        # for j in range(self.width):
+        # if col[i,j] >= 0:
+        # print(f"at ({i},{j}) {name_colors[col[i,j]]} {name_shapes[shp[i,j]]}")
+
+        for i in range(self.height):
+            for j in range(self.width):
+                if col[i, j] >= 0:
+                    print(f"{name_colors[col[i,j]][0]}{name_shapes[shp[i,j]]}", end="")
+                elif j == 0:
+                    print(" +", end="")
+                else:
+                    print("-+", end="")
+                if j < self.width - 1:
+                    print("--", end="")
+                else:
+                    print("")
+            if i < self.height - 1:
+                for j in range(self.width - 1):
+                    print(" |  ", end="")
+                print(" |")
+
+    def grid_positions(self, scene):
+        col, shp = scene
+
+        properties = []
+
+        for i in range(self.height):
+            for j in range(self.width):
+                if col[i, j] >= 0:
+                    n = f"{name_colors[col[i,j]]} {name_shapes[shp[i,j]]}"
+                    properties += [f"a {n} at {i} {j}"]
+
+        return properties
+
+    def all_properties(self, scene):
+        col, shp = scene
+
+        properties = []
+
+        for i1 in range(self.height):
+            for j1 in range(self.width):
+                if col[i1, j1] >= 0:
+                    n1 = f"{name_colors[col[i1,j1]]} {name_shapes[shp[i1,j1]]}"
+                    properties += [f"there is a {n1}"]
+                    if i1 < self.height // 2:
+                        properties += [f"a {n1} is in the top half"]
+                    if i1 >= self.height // 2:
+                        properties += [f"a {n1} is in the bottom half"]
+                    if j1 < self.width // 2:
+                        properties += [f"a {n1} is in the left half"]
+                    if j1 >= self.width // 2:
+                        properties += [f"a {n1} is in the right half"]
+                    for i2 in range(self.height):
+                        for j2 in range(self.width):
+                            if col[i2, j2] >= 0:
+                                n2 = f"{name_colors[col[i2,j2]]} {name_shapes[shp[i2,j2]]}"
+                                if i1 > i2:
+                                    properties += [f"a {n1} is below a {n2}"]
+                                if i1 < i2:
+                                    properties += [f"a {n1} is above a {n2}"]
+                                if j1 > j2:
+                                    properties += [f"a {n1} is right of a {n2}"]
+                                if j1 < j2:
+                                    properties += [f"a {n1} is left of a {n2}"]
+
+        return properties
+
+    def generate_example(self):
+        while True:
+            while True:
+                scene = self.generate_scene()
+                true = self.all_properties(scene)
+                if len(true) >= self.nb_questions:
+                    break
+
+            start = self.grid_positions(scene)
+
+            for a in range(10):
+                col, shp = scene
+                col, shp = col.view(-1), shp.view(-1)
+                p = torch.randperm(col.size(0))
+                col, shp = col[p], shp[p]
+                other_scene = (
+                    col.view(self.height, self.width),
+                    shp.view(self.height, self.width),
+                )
+                # other_scene = self.generate_scene()
+                false = list(set(self.all_properties(other_scene)) - set(true))
+                if len(false) >= self.nb_questions:
+                    break
+
+            if a < 10:
+                break
+
+        true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions]]
+        false = [false[k] for k in torch.randperm(len(false))[: self.nb_questions]]
+        true = [(q, "yes") for q in true]
+        false = [(q, "no") for q in false]
+
+        union = true + false
+        questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]]
+
+        return scene, questions
+
+
+######################################################################
+
+if __name__ == "__main__":
+    grid_factory = GridFactory()
+    scene, questions = grid_factory.generate_example()
+    grid_factory.print_scene(scene)
+    print(questions)
+
+######################################################################
diff --git a/main.py b/main.py
index 8081850..ff831f4 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -89,13 +89,13 @@ parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
 ##############################
 # rpl options
 
-parser.add_argument("--rpl_nb_starting_values", type=int, default=5)
+parser.add_argument("--rpl_nb_starting_values", type=int, default=3)
 
 parser.add_argument("--rpl_max_input", type=int, default=9)
 
-parser.add_argument("--rpl_prog_len", type=int, default=10)
+parser.add_argument("--rpl_prog_len", type=int, default=8)
 
-parser.add_argument("--rpl_nb_runs", type=int, default=8)
+parser.add_argument("--rpl_nb_runs", type=int, default=5)
 
 parser.add_argument("--rpl_no_prog", action="store_true", default=False)
 
@@ -249,10 +249,10 @@ default_task_args = {
         "nb_test_samples": 10000,
     },
     "rpl": {
-        "model": "352M",
+        "model": "122M",
         "nb_epochs": 50,
-        "batch_size": 10,
-        "nb_train_samples": 2500000,
+        "batch_size": 5,
+        "nb_train_samples": 1000000,
         "nb_test_samples": 10000,
     },
     "world": {