70f77396f33af428c9ee8f7c1f140a895c474de9
[picoclvr.git] / grid.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math
9 import torch, torchvision
10 import torch.nn.functional as F
11
# Shape labels; a value s >= 0 in a scene's shape tensor means name_shapes[s].
name_shapes = list("ABCDEF")

# Colour names; a value c >= 0 in a scene's colour tensor means name_colors[c].
# print_scene() abbreviates each colour to its first letter.
name_colors = "red yellow blue green white purple".split()
15
16 ######################################################################
17
18
19 class GridFactory:
20     def __init__(
21         self,
22         height=4,
23         width=4,
24         max_nb_items=4,
25         max_nb_transformations=4,
26         nb_questions=4,
27     ):
28         self.height = height
29         self.width = width
30         self.max_nb_items = max_nb_items
31         self.nb_questions = nb_questions
32
33     def generate_scene(self):
34         nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2
35         col = torch.full((self.height * self.width,), -1)
36         shp = torch.full((self.height * self.width,), -1)
37         a = torch.randperm(len(name_colors) * len(name_shapes))[:nb_items]
38         col[:nb_items] = a % len(name_colors)
39         shp[:nb_items] = a // len(name_colors)
40         i = torch.randperm(self.height * self.width)
41         col = col[i]
42         shp = shp[i]
43         return col.reshape(self.height, self.width), shp.reshape(
44             self.height, self.width
45         )
46
47     def random_transformations(self):
48         nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item()
49
50     def print_scene(self, scene):
51         col, shp = scene
52
53         # for i in range(self.height):
54         # for j in range(self.width):
55         # if col[i,j] >= 0:
56         # print(f"at ({i},{j}) {name_colors[col[i,j]]} {name_shapes[shp[i,j]]}")
57
58         for i in range(self.height):
59             for j in range(self.width):
60                 if col[i, j] >= 0:
61                     print(f"{name_colors[col[i,j]][0]}{name_shapes[shp[i,j]]}", end="")
62                 elif j == 0:
63                     print(" +", end="")
64                 else:
65                     print("-+", end="")
66                 if j < self.width - 1:
67                     print("--", end="")
68                 else:
69                     print("")
70             if i < self.height - 1:
71                 for j in range(self.width - 1):
72                     print(" |  ", end="")
73                 print(" |")
74
75     def grid_positions(self, scene):
76         col, shp = scene
77
78         properties = []
79
80         for i in range(self.height):
81             for j in range(self.width):
82                 if col[i, j] >= 0:
83                     n = f"{name_colors[col[i,j]]} {name_shapes[shp[i,j]]}"
84                     properties += [f"a {n} at {i} {j}"]
85
86         return properties
87
88     def all_properties(self, scene):
89         col, shp = scene
90
91         properties = []
92
93         for i1 in range(self.height):
94             for j1 in range(self.width):
95                 if col[i1, j1] >= 0:
96                     n1 = f"{name_colors[col[i1,j1]]} {name_shapes[shp[i1,j1]]}"
97                     properties += [f"there is a {n1}"]
98                     if i1 < self.height // 2:
99                         properties += [f"a {n1} is in the top half"]
100                     if i1 >= self.height // 2:
101                         properties += [f"a {n1} is in the bottom half"]
102                     if j1 < self.width // 2:
103                         properties += [f"a {n1} is in the left half"]
104                     if j1 >= self.width // 2:
105                         properties += [f"a {n1} is in the right half"]
106                     for i2 in range(self.height):
107                         for j2 in range(self.width):
108                             if col[i2, j2] >= 0:
109                                 n2 = f"{name_colors[col[i2,j2]]} {name_shapes[shp[i2,j2]]}"
110                                 if i1 > i2:
111                                     properties += [f"a {n1} is below a {n2}"]
112                                 if i1 < i2:
113                                     properties += [f"a {n1} is above a {n2}"]
114                                 if j1 > j2:
115                                     properties += [f"a {n1} is right of a {n2}"]
116                                 if j1 < j2:
117                                     properties += [f"a {n1} is left of a {n2}"]
118
119         return properties
120
121     def generate_scene_and_questions(self):
122         while True:
123             while True:
124                 scene = self.generate_scene()
125                 true = self.all_properties(scene)
126                 if len(true) >= self.nb_questions:
127                     break
128
129             start = self.grid_positions(scene)
130
131             for a in range(10):
132                 col, shp = scene
133                 col, shp = col.view(-1), shp.view(-1)
134                 p = torch.randperm(col.size(0))
135                 col, shp = col[p], shp[p]
136                 other_scene = (
137                     col.view(self.height, self.width),
138                     shp.view(self.height, self.width),
139                 )
140                 # other_scene = self.generate_scene()
141                 false = list(set(self.all_properties(other_scene)) - set(true))
142                 if len(false) >= self.nb_questions:
143                     break
144
145             # print(f"{a=}")
146
147             if a < 10:
148                 break
149
150         true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions]]
151         false = [false[k] for k in torch.randperm(len(false))[: self.nb_questions]]
152         true = ["<prop> " + q + " <true>" for q in true]
153         false = ["<prop> " + q + " <false>" for q in false]
154
155         union = true + false
156         questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]]
157
158         result = " ".join(
159             ["<obj> " + x for x in self.grid_positions(scene)] + questions
160         )
161
162         return scene, result
163
164     def generate_samples(self, nb, progress_bar=None):
165         result = []
166
167         r = range(nb)
168         if progress_bar is not None:
169             r = progress_bar(r)
170
171         for _ in r:
172             result.append(self.generate_scene_and_questions()[1])
173
174         return result
175
176
177 ######################################################################
178
if __name__ == "__main__":
    import time

    factory = GridFactory()

    # Rough throughput benchmark of the sample generator.
    t0 = time.perf_counter()
    samples = factory.generate_samples(10000)
    elapsed = time.perf_counter() - t0
    print(f"{len(samples) / elapsed:.02f} samples per second")

    # Show one scene and its serialized sample.
    scene, questions = factory.generate_scene_and_questions()
    factory.print_scene(scene)
    print(questions)
192
193 ######################################################################