Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 25 Aug 2023 17:21:50 +0000 (19:21 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 25 Aug 2023 17:21:50 +0000 (19:21 +0200)
grid.py
main.py
tasks.py

diff --git a/grid.py b/grid.py
index 433cfd5..f72c8e3 100755 (executable)
--- a/grid.py
+++ b/grid.py
@@ -19,34 +19,31 @@ name_colors = ["red", "yellow", "blue", "green", "white", "purple"]
 class GridFactory:
     def __init__(
         self,
-        height=4,
-        width=4,
+        size=4,
         max_nb_items=4,
-        max_nb_transformations=4,
+        max_nb_transformations=3,
         nb_questions=4,
     ):
-        self.height = height
-        self.width = width
+        self.size = size
         self.max_nb_items = max_nb_items
         self.max_nb_transformations = max_nb_transformations
         self.nb_questions = nb_questions
 
     def generate_scene(self):
         nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2
-        col = torch.full((self.height * self.width,), -1)
-        shp = torch.full((self.height * self.width,), -1)
+        col = torch.full((self.size * self.size,), -1)
+        shp = torch.full((self.size * self.size,), -1)
         a = torch.randperm(len(name_colors) * len(name_shapes))[:nb_items]
         col[:nb_items] = a % len(name_colors)
         shp[:nb_items] = a // len(name_colors)
-        i = torch.randperm(self.height * self.width)
+        i = torch.randperm(self.size * self.size)
         col = col[i]
         shp = shp[i]
-        return col.reshape(self.height, self.width), shp.reshape(
-            self.height, self.width
-        )
+        return col.reshape(self.size, self.size), shp.reshape(self.size, self.size)
 
     def random_transformations(self, scene):
         col, shp = scene
+
         descriptions = []
         nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item()
         transformations = torch.randint(5, (nb_transformations,))
@@ -68,30 +65,32 @@ class GridFactory:
                 col, shp = col.flip(1).t(), shp.flip(1).t()
                 descriptions += ["<chg> rotate 270 degrees"]
 
-        return (col.contiguous(), shp.contiguous()), descriptions
+            col, shp = col.contiguous(), shp.contiguous()
+
+        return (col, shp), descriptions
 
     def print_scene(self, scene):
         col, shp = scene
 
-        # for i in range(self.height):
-        # for j in range(self.width):
+        # for i in range(self.size):
+        # for j in range(self.size):
         # if col[i,j] >= 0:
         # print(f"at ({i},{j}) {name_colors[col[i,j]]} {name_shapes[shp[i,j]]}")
 
-        for i in range(self.height):
-            for j in range(self.width):
+        for i in range(self.size):
+            for j in range(self.size):
                 if col[i, j] >= 0:
                     print(f"{name_colors[col[i,j]][0]}{name_shapes[shp[i,j]]}", end="")
                 elif j == 0:
                     print(" +", end="")
                 else:
                     print("-+", end="")
-                if j < self.width - 1:
+                if j < self.size - 1:
                     print("--", end="")
                 else:
                     print("")
-            if i < self.height - 1:
-                for j in range(self.width - 1):
+            if i < self.size - 1:
+                for j in range(self.size - 1):
                     print(" |  ", end="")
                 print(" |")
 
@@ -100,8 +99,8 @@ class GridFactory:
 
         properties = []
 
-        for i in range(self.height):
-            for j in range(self.width):
+        for i in range(self.size):
+            for j in range(self.size):
                 if col[i, j] >= 0:
                     n = f"{name_colors[col[i,j]]} {name_shapes[shp[i,j]]}"
                     properties += [f"a {n} at {i} {j}"]
@@ -113,21 +112,21 @@ class GridFactory:
 
         properties = []
 
-        for i1 in range(self.height):
-            for j1 in range(self.width):
+        for i1 in range(self.size):
+            for j1 in range(self.size):
                 if col[i1, j1] >= 0:
                     n1 = f"{name_colors[col[i1,j1]]} {name_shapes[shp[i1,j1]]}"
                     properties += [f"there is a {n1}"]
-                    if i1 < self.height // 2:
+                    if i1 < self.size // 2:
                         properties += [f"a {n1} is in the top half"]
-                    if i1 >= self.height // 2:
+                    if i1 >= self.size // 2:
                         properties += [f"a {n1} is in the bottom half"]
-                    if j1 < self.width // 2:
+                    if j1 < self.size // 2:
                         properties += [f"a {n1} is in the left half"]
-                    if j1 >= self.width // 2:
+                    if j1 >= self.size // 2:
                         properties += [f"a {n1} is in the right half"]
-                    for i2 in range(self.height):
-                        for j2 in range(self.width):
+                    for i2 in range(self.size):
+                        for j2 in range(self.size):
                             if col[i2, j2] >= 0:
                                 n2 = f"{name_colors[col[i2,j2]]} {name_shapes[shp[i2,j2]]}"
                                 if i1 > i2:
@@ -153,22 +152,22 @@ class GridFactory:
 
             scene, transformations = self.random_transformations(scene)
 
+            # transformations=[]
+
             for a in range(10):
                 col, shp = scene
                 col, shp = col.view(-1), shp.view(-1)
                 p = torch.randperm(col.size(0))
                 col, shp = col[p], shp[p]
                 other_scene = (
-                    col.view(self.height, self.width),
-                    shp.view(self.height, self.width),
+                    col.view(self.size, self.size),
+                    shp.view(self.size, self.size),
                 )
                 # other_scene = self.generate_scene()
                 false = list(set(self.all_properties(other_scene)) - set(true))
                 if len(false) >= self.nb_questions:
                     break
 
-            # print(f"{a=}")
-
             if a < 10:
                 break
 
diff --git a/main.py b/main.py
index 00e19ac..704dff5 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -99,6 +99,11 @@ parser.add_argument("--rpl_nb_runs", type=int, default=5)
 
 parser.add_argument("--rpl_no_prog", action="store_true", default=False)
 
+##############################
+# grid options
+
+parser.add_argument("--grid_size", type=int, default=6)
+
 ##############################
 # picoclvr options
 
@@ -517,8 +522,7 @@ elif args.task == "grid":
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
         batch_size=args.batch_size,
-        height=args.picoclvr_height,
-        width=args.picoclvr_width,
+        size=args.grid_size,
         logger=log_string,
         device=device,
     )
index 0ab1823..2c2f914 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -1459,8 +1459,7 @@ class Grid(Task):
         nb_train_samples,
         nb_test_samples,
         batch_size,
-        height,
-        width,
+        size,
         logger=None,
         device=torch.device("cpu"),
     ):
@@ -1468,7 +1467,7 @@ class Grid(Task):
 
         self.device = device
         self.batch_size = batch_size
-        self.grid_factory = grid.GridFactory(height=height, width=width)
+        self.grid_factory = grid.GridFactory(size=size)
 
         if logger is not None:
             logger(