From ac3d9ba45d72a7f3e399de4e3614698ac5e0ce39 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 16 Jan 2024 14:31:03 +0100 Subject: [PATCH] Update. --- grid.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/grid.py b/grid.py index 268f4ee..2135710 100755 --- a/grid.py +++ b/grid.py @@ -9,10 +9,6 @@ import math import torch, torchvision import torch.nn.functional as F -name_shapes = ["A", "B", "C", "D", "E", "F"] - -name_colors = ["red", "yellow", "blue", "green", "white", "purple"] - ###################################################################### @@ -23,20 +19,24 @@ class GridFactory: max_nb_items=4, max_nb_transformations=3, nb_questions=4, + nb_shapes=6, + nb_colors=6, ): assert size % 2 == 0 self.size = size self.max_nb_items = max_nb_items self.max_nb_transformations = max_nb_transformations self.nb_questions = nb_questions + self.name_shapes = ["A", "B", "C", "D", "E", "F"] + self.name_colors = ["red", "yellow", "blue", "green", "white", "purple"] def generate_scene(self): nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2 col = torch.full((self.size * self.size,), -1) shp = torch.full((self.size * self.size,), -1) - a = torch.randperm(len(name_colors) * len(name_shapes))[:nb_items] - col[:nb_items] = a % len(name_colors) - shp[:nb_items] = a // len(name_colors) + a = torch.randperm(len(self.name_colors) * len(self.name_shapes))[:nb_items] + col[:nb_items] = a % len(self.name_colors) + shp[:nb_items] = a // len(self.name_colors) i = torch.randperm(self.size * self.size) col = col[i] shp = shp[i] @@ -76,12 +76,15 @@ class GridFactory: # for i in range(self.size): # for j in range(self.size): # if col[i,j] >= 0: - # print(f"at ({i},{j}) {name_colors[col[i,j]]} {name_shapes[shp[i,j]]}") + # print(f"at ({i},{j}) {self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}") for i in range(self.size): for j in range(self.size): if col[i, j] >= 0: - print(f"{name_colors[col[i,j]][0]}{name_shapes[shp[i,j]]}", end="") + print( + f"{self.name_colors[col[i,j]][0]}{self.name_shapes[shp[i,j]]}", + end="", + ) elif j == 0: print(" +", end="") else: @@ -103,7 +106,7 @@ class GridFactory: for i in range(self.size): for j in range(self.size): if col[i, j] >= 0: - n = f"{name_colors[col[i,j]]} {name_shapes[shp[i,j]]}" + n = f"{self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}" properties += [f"a {n} at {i} {j}"] return properties @@ -116,7 +119,9 @@ class GridFactory: for i1 in range(self.size): for j1 in range(self.size): if col[i1, j1] >= 0: - n1 = f"{name_colors[col[i1,j1]]} {name_shapes[shp[i1,j1]]}" + n1 = ( + f"{self.name_colors[col[i1,j1]]} {self.name_shapes[shp[i1,j1]]}" + ) properties += [f"there is a {n1}"] if i1 < self.size // 2: properties += [f"a {n1} is in the top half"] @@ -129,7 +134,7 @@ class GridFactory: for i2 in range(self.size): for j2 in range(self.size): if col[i2, j2] >= 0: - n2 = f"{name_colors[col[i2,j2]]} {name_shapes[shp[i2,j2]]}" + n2 = f"{self.name_colors[col[i2,j2]]} {self.name_shapes[shp[i2,j2]]}" if i1 > i2: properties += [f"a {n1} is below a {n2}"] if i1 < i2: -- 2.20.1