Update.
[mygptrnn.git] / grid.py
diff --git a/grid.py b/grid.py
index 268f4ee..f9f1557 100755 (executable)
--- a/grid.py
+++ b/grid.py
@@ -9,10 +9,6 @@ import math
 import torch, torchvision
 import torch.nn.functional as F
 
-name_shapes = ["A", "B", "C", "D", "E", "F"]
-
-name_colors = ["red", "yellow", "blue", "green", "white", "purple"]
-
 ######################################################################
 
 
@@ -23,20 +19,160 @@ class GridFactory:
         max_nb_items=4,
         max_nb_transformations=3,
         nb_questions=4,
+        nb_shapes=6,
+        nb_colors=6,
     ):
         assert size % 2 == 0
         self.size = size
         self.max_nb_items = max_nb_items
         self.max_nb_transformations = max_nb_transformations
         self.nb_questions = nb_questions
+        self.name_shapes = [chr(ord("A") + k) for k in range(nb_shapes)]
+        self.name_colors = [
+            "red",
+            "yellow",
+            "blue",
+            "green",
+            "white",
+            "black",
+            "maroon",
+            "dark_red",
+            "brown",
+            "firebrick",
+            "crimson",
+            "tomato",
+            "coral",
+            "indian_red",
+            "light_coral",
+            "dark_salmon",
+            "salmon",
+            "light_salmon",
+            "orange_red",
+            "dark_orange",
+            "orange",
+            "gold",
+            "dark_golden_rod",
+            "golden_rod",
+            "pale_golden_rod",
+            "dark_khaki",
+            "khaki",
+            "olive",
+            "yellow_green",
+            "dark_olive_green",
+            "olive_drab",
+            "lawn_green",
+            "chartreuse",
+            "green_yellow",
+            "dark_green",
+            "forest_green",
+            "lime",
+            "lime_green",
+            "light_green",
+            "pale_green",
+            "dark_sea_green",
+            "medium_spring_green",
+            "spring_green",
+            "sea_green",
+            "medium_aqua_marine",
+            "medium_sea_green",
+            "light_sea_green",
+            "dark_slate_gray",
+            "teal",
+            "dark_cyan",
+            "aqua",
+            "cyan",
+            "light_cyan",
+            "dark_turquoise",
+            "turquoise",
+            "medium_turquoise",
+            "pale_turquoise",
+            "aqua_marine",
+            "powder_blue",
+            "cadet_blue",
+            "steel_blue",
+            "corn_flower_blue",
+            "deep_sky_blue",
+            "dodger_blue",
+            "light_blue",
+            "sky_blue",
+            "light_sky_blue",
+            "midnight_blue",
+            "navy",
+            "dark_blue",
+            "medium_blue",
+            "royal_blue",
+            "blue_violet",
+            "indigo",
+            "dark_slate_blue",
+            "slate_blue",
+            "medium_slate_blue",
+            "medium_purple",
+            "dark_magenta",
+            "dark_violet",
+            "dark_orchid",
+            "medium_orchid",
+            "purple",
+            "thistle",
+            "plum",
+            "violet",
+            "magenta",
+            "orchid",
+            "medium_violet_red",
+            "pale_violet_red",
+            "deep_pink",
+            "hot_pink",
+            "light_pink",
+            "pink",
+            "antique_white",
+            "beige",
+            "bisque",
+            "blanched_almond",
+            "wheat",
+            "corn_silk",
+            "lemon_chiffon",
+            "light_golden_rod_yellow",
+            "light_yellow",
+            "saddle_brown",
+            "sienna",
+            "chocolate",
+            "peru",
+            "sandy_brown",
+            "burly_wood",
+            "tan",
+            "rosy_brown",
+            "moccasin",
+            "navajo_white",
+            "peach_puff",
+            "misty_rose",
+            "lavender_blush",
+            "linen",
+            "old_lace",
+            "papaya_whip",
+            "sea_shell",
+            "mint_cream",
+            "slate_gray",
+            "light_slate_gray",
+            "light_steel_blue",
+            "lavender",
+            "floral_white",
+            "alice_blue",
+            "ghost_white",
+            "honeydew",
+            "ivory",
+            "azure",
+            "snow",
+            "silver",
+            "gainsboro",
+            "white_smoke",
+        ][:nb_colors]
 
     def generate_scene(self):
         nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2
         col = torch.full((self.size * self.size,), -1)
         shp = torch.full((self.size * self.size,), -1)
-        a = torch.randperm(len(name_colors) * len(name_shapes))[:nb_items]
-        col[:nb_items] = a % len(name_colors)
-        shp[:nb_items] = a // len(name_colors)
+        a = torch.randperm(len(self.name_colors) * len(self.name_shapes))[:nb_items]
+        col[:nb_items] = a % len(self.name_colors)
+        shp[:nb_items] = a // len(self.name_colors)
         i = torch.randperm(self.size * self.size)
         col = col[i]
         shp = shp[i]
@@ -76,12 +212,15 @@ class GridFactory:
         # for i in range(self.size):
         # for j in range(self.size):
         # if col[i,j] >= 0:
-        # print(f"at ({i},{j}) {name_colors[col[i,j]]} {name_shapes[shp[i,j]]}")
+        # print(f"at ({i},{j}) {self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}")
 
         for i in range(self.size):
             for j in range(self.size):
                 if col[i, j] >= 0:
-                    print(f"{name_colors[col[i,j]][0]}{name_shapes[shp[i,j]]}", end="")
+                    print(
+                        f"{self.name_colors[col[i,j]][0]}{self.name_shapes[shp[i,j]]}",
+                        end="",
+                    )
                 elif j == 0:
                     print(" +", end="")
                 else:
@@ -103,7 +242,7 @@ class GridFactory:
         for i in range(self.size):
             for j in range(self.size):
                 if col[i, j] >= 0:
-                    n = f"{name_colors[col[i,j]]} {name_shapes[shp[i,j]]}"
+                    n = f"{self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}"
                     properties += [f"a {n} at {i} {j}"]
 
         return properties
@@ -116,7 +255,9 @@ class GridFactory:
         for i1 in range(self.size):
             for j1 in range(self.size):
                 if col[i1, j1] >= 0:
-                    n1 = f"{name_colors[col[i1,j1]]} {name_shapes[shp[i1,j1]]}"
+                    n1 = (
+                        f"{self.name_colors[col[i1,j1]]} {self.name_shapes[shp[i1,j1]]}"
+                    )
                     properties += [f"there is a {n1}"]
                     if i1 < self.size // 2:
                         properties += [f"a {n1} is in the top half"]
@@ -129,7 +270,7 @@ class GridFactory:
                     for i2 in range(self.size):
                         for j2 in range(self.size):
                             if col[i2, j2] >= 0:
-                                n2 = f"{name_colors[col[i2,j2]]} {name_shapes[shp[i2,j2]]}"
+                                n2 = f"{self.name_colors[col[i2,j2]]} {self.name_shapes[shp[i2,j2]]}"
                                 if i1 > i2:
                                     properties += [f"a {n1} is below a {n2}"]
                                 if i1 < i2: