Oups
[picoclvr.git] / grid.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math
9 import torch, torchvision
10 import torch.nn.functional as F
11
12 ######################################################################
13
14
class GridFactory:
    """Generate random grid scenes and token sequences describing them.

    A *scene* is a pair ``(col, shp)`` of ``size x size`` integer tensors.
    For every cell they hold the color index and the shape index of the
    item in that cell, or ``-1`` in both tensors when the cell is empty.
    """

    def __init__(
        self,
        size=6,
        max_nb_items=4,
        max_nb_transformations=3,
        nb_questions=4,
        nb_shapes=6,
        nb_colors=6,
        nb_play_steps=3,
    ):
        """Store the generation parameters and build the vocabularies.

        Args:
            size: side of the square grid, must be even so that the grid
                splits into top/bottom and left/right halves.
            max_nb_items: maximum number of items in a generated scene.
            max_nb_transformations: maximum number of global
                transformations applied in a question sample.
            nb_questions: number of properties asked per question sample.
            nb_shapes: number of shape names to use (1 to 6).
            nb_colors: number of color names to use (1 to 6).
            nb_play_steps: number of scenes in a "play" sample.

        Raises:
            ValueError: if nb_shapes or nb_colors is outside the range of
                available names.
        """
        assert size % 2 == 0

        all_shapes = ["A", "B", "C", "D", "E", "F"]
        all_colors = ["red", "yellow", "blue", "green", "white", "purple"]

        if not 1 <= nb_shapes <= len(all_shapes):
            raise ValueError(f"nb_shapes must be in [1, {len(all_shapes)}]")
        if not 1 <= nb_colors <= len(all_colors):
            raise ValueError(f"nb_colors must be in [1, {len(all_colors)}]")

        self.size = size
        self.max_nb_items = max_nb_items
        self.max_nb_transformations = max_nb_transformations
        self.nb_questions = nb_questions
        self.nb_play_steps = nb_play_steps
        self.nb_shapes = nb_shapes
        self.nb_colors = nb_colors

        # The requested counts were previously ignored; they now select a
        # prefix of the available names. The defaults keep all six.
        self.name_shapes = all_shapes[:nb_shapes]
        self.name_colors = all_colors[:nb_colors]
        # The "v"-prefixed names are the tokens of the visual encoding.
        self.vname_shapes = ["v" + s for s in self.name_shapes]
        self.vname_colors = ["v" + c for c in self.name_colors]

    def generate_scene(self):
        """Return a random scene with distinct (color, shape) items.

        The number of items is drawn uniformly in [2, max_nb_items], then
        capped by the number of cells and of distinct (color, shape)
        pairs so that the draw without replacement is always possible.
        """
        nb_cells = self.size * self.size
        nb_kinds = len(self.name_colors) * len(self.name_shapes)

        nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2
        nb_items = min(nb_items, nb_kinds, nb_cells)

        # Each value in [0, nb_kinds) encodes one (color, shape) pair, so
        # a truncated permutation yields items that are all different.
        kinds = torch.randperm(nb_kinds)[:nb_items]

        col = torch.full((nb_cells,), -1)
        shp = torch.full((nb_cells,), -1)
        col[:nb_items] = kinds % len(self.name_colors)
        shp[:nb_items] = kinds // len(self.name_colors)

        # Scatter the items over the grid with one shared permutation so
        # that color and shape stay aligned.
        where = torch.randperm(nb_cells)
        return (
            col[where].reshape(self.size, self.size),
            shp[where].reshape(self.size, self.size),
        )

    def random_object_move(self, scene):
        """Move one random item to a random empty 4-neighbor cell.

        The tensors of ``scene`` are modified in place and also returned.
        An item is re-drawn until one with a free neighbor is found, so
        the scene must contain at least one item able to move.
        """
        col, shp = scene
        while True:
            occupied = (col.flatten() >= 0).nonzero()
            pick = torch.randint(occupied.size(0), (1,)).item()
            i, j = divmod(occupied[pick].item(), self.size)

            neighbors = [(i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)]
            free = [
                (u, v)
                for u, v in neighbors
                if 0 <= u < self.size and 0 <= v < self.size and col[u, v] < 0
            ]

            if free:
                ni, nj = free[torch.randint(len(free), (1,)).item()]
                col[ni, nj] = col[i, j]
                shp[ni, nj] = shp[i, j]
                col[i, j] = -1
                shp[i, j] = -1
                break

        return col, shp

    def transformation(self, t, scene):
        """Apply global transformation number ``t`` to ``scene``.

        Args:
            t: integer in [0, 4] (a Python int or a one-element tensor).
            scene: the ``(col, shp)`` pair to transform.

        Returns:
            ``((col, shp), description)`` with contiguous new tensors and
            the "<chg> ..." token string naming the transformation.

        Raises:
            ValueError: if ``t`` is not a known transformation.
        """
        col, shp = scene
        t = int(t)

        if t == 0:
            col, shp = col.flip(0), shp.flip(0)
            description = "<chg> vertical flip"
        elif t == 1:
            col, shp = col.flip(1), shp.flip(1)
            description = "<chg> horizontal flip"
        elif t == 2:
            col, shp = col.flip(0).t(), shp.flip(0).t()
            description = "<chg> rotate 90 degrees"
        elif t == 3:
            col, shp = col.flip(0).flip(1), shp.flip(0).flip(1)
            description = "<chg> rotate 180 degrees"
        elif t == 4:
            col, shp = col.flip(1).t(), shp.flip(1).t()
            description = "<chg> rotate 270 degrees"
        else:
            # Previously fell through to an UnboundLocalError.
            raise ValueError(f"unknown transformation {t}")

        return (col.contiguous(), shp.contiguous()), description

    def random_transformations(self, scene):
        """Apply between 0 and max_nb_transformations random transformations.

        Returns the transformed scene and the list of descriptions, in
        the order the transformations were applied.
        """
        nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item()
        transformations = torch.randint(5, (nb_transformations,))

        descriptions = []
        for t in transformations.tolist():
            scene, description = self.transformation(t, scene)
            descriptions.append(description)

        return scene, descriptions

    def visual_scene2str(self, scene):
        """Encode the scene cell by cell, row-major, as a token string.

        Every cell yields two tokens: its visual color and shape names,
        or "v_" and "v+" when it is empty.
        """
        col, shp = scene
        col, shp = col.tolist(), shp.tolist()

        tokens = []
        for i in range(self.size):
            for j in range(self.size):
                if col[i][j] >= 0:
                    tokens.append(self.vname_colors[col[i][j]])
                    tokens.append(self.vname_shapes[shp[i][j]])
                else:
                    tokens.append("v_")
                    tokens.append("v+")

        return " ".join(tokens)

    def print_scene(self, scene):
        """Print an ASCII drawing of the scene on standard output.

        An item is shown as the first letter of its color followed by its
        shape name, an empty cell as a grid intersection.
        """
        col, shp = scene
        col, shp = col.tolist(), shp.tolist()
        last = self.size - 1

        for i in range(self.size):
            cells = []
            for j in range(self.size):
                if col[i][j] >= 0:
                    cells.append(
                        f"{self.name_colors[col[i][j]][0]}{self.name_shapes[shp[i][j]]}"
                    )
                elif j == 0:
                    cells.append(" +")
                else:
                    cells.append("-+")
            print("--".join(cells))
            # Vertical bars between two consecutive rows.
            if i < last:
                print(" |  " * last + " |")

    def _named_items(self, scene):
        """List ``(i, j, "color shape")`` for every item, in row-major order."""
        col, shp = scene
        col, shp = col.tolist(), shp.tolist()
        return [
            (i, j, f"{self.name_colors[col[i][j]]} {self.name_shapes[shp[i][j]]}")
            for i in range(self.size)
            for j in range(self.size)
            if col[i][j] >= 0
        ]

    def grid_positions(self, scene):
        """Return one "a <color> <shape> at <row> <column>" string per item."""
        return [f"a {n} at {i} {j}" for i, j, n in self._named_items(scene)]

    def all_properties(self, scene):
        """Return every true property of the scene, as strings.

        For each item this lists its existence, the halves of the grid it
        lies in, and its relations (below, above, right of, left of, next
        to) with every other item. The result may contain duplicates.
        """
        items = self._named_items(scene)
        half = self.size // 2

        properties = []
        for i1, j1, n1 in items:
            properties.append(f"there is a {n1}")
            properties.append(
                f"a {n1} is in the top half"
                if i1 < half
                else f"a {n1} is in the bottom half"
            )
            properties.append(
                f"a {n1} is in the left half"
                if j1 < half
                else f"a {n1} is in the right half"
            )
            for i2, j2, n2 in items:
                if i1 > i2:
                    properties.append(f"a {n1} is below a {n2}")
                if i1 < i2:
                    properties.append(f"a {n1} is above a {n2}")
                if j1 > j2:
                    properties.append(f"a {n1} is right of a {n2}")
                if j1 < j2:
                    properties.append(f"a {n1} is left of a {n2}")
                if abs(i1 - i2) + abs(j1 - j2) == 1:
                    properties.append(f"a {n1} is next to a {n2}")

        return properties

    def generate_scene_and_play(self):
        """Return a "play" sample: nb_play_steps visual scenes joined by " | ".

        Each step after the first is, with probability 1/4, a random
        global transformation and otherwise a random single-item move.
        """
        scene = self.generate_scene()
        steps = [self.visual_scene2str(scene)]

        for _ in range(self.nb_play_steps - 1):
            if torch.randint(4, (1,)).item() == 0:
                scene, _ = self.transformation(torch.randint(5, (1,)).item(), scene)
            else:
                scene = self.random_object_move(scene)
            steps.append(self.visual_scene2str(scene))

        return " | ".join(steps)

    def generate_scene_and_questions(self):
        """Return ``(start_scene, scene, result)`` for a question sample.

        ``result`` is the token string made of the item positions of the
        start scene, the descriptions of the transformations leading to
        ``scene``, and nb_questions properties of ``scene`` each tagged
        with "<ans> true" or "<ans> false".
        """
        while True:
            # We generate scenes until we get one with enough
            # properties

            while True:
                start_scene = self.generate_scene()
                scene, transformations = self.random_transformations(start_scene)
                true = self.all_properties(scene)
                if len(true) >= self.nb_questions:
                    break

            # We generate a bunch of false properties by shuffling the
            # scene and sometimes adding properties from totally
            # different scenes. We try ten times to get enough false
            # properties and go back to generating the scene if we do
            # not succeed

            enough_false = False

            for _ in range(10):
                col, shp = scene
                col, shp = col.reshape(-1), shp.reshape(-1)
                p = torch.randperm(col.size(0))
                other_scene = (
                    col[p].view(self.size, self.size),
                    shp[p].view(self.size, self.size),
                )

                false = self.all_properties(other_scene)

                # We sometimes add properties from a totally different
                # scene to have negative "there is a xxx xxx"
                # properties

                if torch.rand(1).item() < 0.2:
                    other_scene = self.generate_scene()
                    false += self.all_properties(other_scene)

                false = list(set(false) - set(true))
                if len(false) >= self.nb_questions:
                    enough_false = True
                    break

            # The loop index is always below ten, so it cannot tell a
            # success from an exhausted retry budget; an explicit flag
            # is required to actually restart with a new scene.
            if enough_false:
                break

        n = self.nb_questions
        true = [true[k] for k in torch.randperm(len(true))[:n].tolist()]
        false = [false[k] for k in torch.randperm(len(false))[:n].tolist()]
        true = ["<prop> " + q + " <ans> true" for q in true]
        false = ["<prop> " + q + " <ans> false" for q in false]

        union = true + false
        questions = [union[k] for k in torch.randperm(len(union))[:n].tolist()]

        result = " ".join(
            ["<obj> " + x for x in self.grid_positions(start_scene)]
            + transformations
            + questions
        )

        return start_scene, scene, result

    def generate_samples(self, nb, fraction_play=0.0, progress_bar=None):
        """Return a list of ``nb`` sample strings.

        Args:
            nb: number of samples to generate.
            fraction_play: probability that a sample is a "play" sample
                instead of a question sample.
            progress_bar: optional callable wrapping the iterable of
                per-sample flags, e.g. to display progress.
        """
        play = torch.rand(nb) < fraction_play
        if progress_bar is not None:
            play = progress_bar(play)

        result = []
        for p in play:
            if p:
                result.append(self.generate_scene_and_play())
            else:
                result.append(self.generate_scene_and_questions()[2])

        return result
285
286
287 ######################################################################
288
if __name__ == "__main__":
    import time

    grid_factory = GridFactory()

    # Throughput benchmark, kept for reference:
    # start_time = time.perf_counter()
    # samples = grid_factory.generate_samples(10000)
    # end_time = time.perf_counter()
    # print(f"{len(samples) / (end_time - start_time):.02f} samples per second")

    # Draw one question sample and show the scene before and after the
    # transformations, followed by the generated token sequence.
    start_scene, scene, questions = grid_factory.generate_scene_and_questions()

    sections = [
        ("-- Original scene -----------------------------", start_scene),
        ("-- Transformed scene --------------------------", scene),
    ]

    for title, shown_scene in sections:
        print()
        print(title)
        print()
        grid_factory.print_scene(shown_scene)

    print()
    print("-- Sequence -----------------------------------")
    print()
    print(questions)

    # Finally show one "play" sample.
    print(grid_factory.generate_scene_and_play())
322
323 ######################################################################