Update.
[mygptrnn.git] / grid.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math
9 import torch, torchvision
10 import torch.nn.functional as F
11
12 ######################################################################
13
14
class GridFactory:
    """Generate small grid-world scenes and true/false statements about them.

    A scene is a pair ``(col, shp)`` of ``size x size`` integer tensors
    holding, for each cell, the index of the color and of the shape of the
    item located there, or -1 when the cell is empty.
    """

    # Full palette; ``nb_colors`` selects a prefix of it.
    _COLOR_NAMES = (
        "red",
        "yellow",
        "blue",
        "green",
        "white",
        "black",
        "maroon",
        "dark_red",
        "brown",
        "firebrick",
        "crimson",
        "tomato",
        "coral",
        "indian_red",
        "light_coral",
        "dark_salmon",
        "salmon",
        "light_salmon",
        "orange_red",
        "dark_orange",
        "orange",
        "gold",
        "dark_golden_rod",
        "golden_rod",
        "pale_golden_rod",
        "dark_khaki",
        "khaki",
        "olive",
        "yellow_green",
        "dark_olive_green",
        "olive_drab",
        "lawn_green",
        "chartreuse",
        "green_yellow",
        "dark_green",
        "forest_green",
        "lime",
        "lime_green",
        "light_green",
        "pale_green",
        "dark_sea_green",
        "medium_spring_green",
        "spring_green",
        "sea_green",
        "medium_aqua_marine",
        "medium_sea_green",
        "light_sea_green",
        "dark_slate_gray",
        "teal",
        "dark_cyan",
        "aqua",
        "cyan",
        "light_cyan",
        "dark_turquoise",
        "turquoise",
        "medium_turquoise",
        "pale_turquoise",
        "aqua_marine",
        "powder_blue",
        "cadet_blue",
        "steel_blue",
        "corn_flower_blue",
        "deep_sky_blue",
        "dodger_blue",
        "light_blue",
        "sky_blue",
        "light_sky_blue",
        "midnight_blue",
        "navy",
        "dark_blue",
        "medium_blue",
        "royal_blue",
        "blue_violet",
        "indigo",
        "dark_slate_blue",
        "slate_blue",
        "medium_slate_blue",
        "medium_purple",
        "dark_magenta",
        "dark_violet",
        "dark_orchid",
        "medium_orchid",
        "purple",
        "thistle",
        "plum",
        "violet",
        "magenta",
        "orchid",
        "medium_violet_red",
        "pale_violet_red",
        "deep_pink",
        "hot_pink",
        "light_pink",
        "pink",
        "antique_white",
        "beige",
        "bisque",
        "blanched_almond",
        "wheat",
        "corn_silk",
        "lemon_chiffon",
        "light_golden_rod_yellow",
        "light_yellow",
        "saddle_brown",
        "sienna",
        "chocolate",
        "peru",
        "sandy_brown",
        "burly_wood",
        "tan",
        "rosy_brown",
        "moccasin",
        "navajo_white",
        "peach_puff",
        "misty_rose",
        "lavender_blush",
        "linen",
        "old_lace",
        "papaya_whip",
        "sea_shell",
        "mint_cream",
        "slate_gray",
        "light_slate_gray",
        "light_steel_blue",
        "lavender",
        "floral_white",
        "alice_blue",
        "ghost_white",
        "honeydew",
        "ivory",
        "azure",
        "snow",
        "silver",
        "gainsboro",
        "white_smoke",
    )

    # Transformation index -> (description token, operation on one grid).
    _TRANSFORMATIONS = (
        ("<chg> vertical flip", lambda x: x.flip(0)),
        ("<chg> horizontal flip", lambda x: x.flip(1)),
        ("<chg> rotate 90 degrees", lambda x: x.flip(0).t()),
        ("<chg> rotate 180 degrees", lambda x: x.flip(0).flip(1)),
        ("<chg> rotate 270 degrees", lambda x: x.flip(1).t()),
    )

    def __init__(
        self,
        size=6,
        max_nb_items=4,
        max_nb_transformations=3,
        nb_questions=4,
        nb_shapes=6,
        nb_colors=6,
    ):
        """Configure the generator.

        Args:
            size: side of the square grid; must be even so that the
                "half" properties split the grid evenly.
            max_nb_items: upper bound on the number of items in a scene
                (a scene always holds at least two).
            max_nb_transformations: upper bound (inclusive) on the number
                of transformations applied to a scene.
            nb_questions: number of questions emitted per sample.
            nb_shapes: number of shape names, taken as "A", "B", ...
            nb_colors: number of color names, taken from the start of the
                built-in palette.
        """
        assert size % 2 == 0
        self.size = size
        self.max_nb_items = max_nb_items
        self.max_nb_transformations = max_nb_transformations
        self.nb_questions = nb_questions
        self.name_shapes = [chr(ord("A") + k) for k in range(nb_shapes)]
        self.name_colors = list(self._COLOR_NAMES[:nb_colors])

    def generate_scene(self):
        """Return a random scene ``(col, shp)``.

        Between 2 and ``max_nb_items`` items with pairwise distinct
        (color, shape) combinations are placed in distinct random cells;
        every other cell holds -1 in both tensors.
        """
        nb_cells = self.size * self.size
        nb_colors = len(self.name_colors)
        nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2

        col = torch.full((nb_cells,), -1)
        shp = torch.full((nb_cells,), -1)

        # Each drawn integer encodes one (color, shape) pair, so sampling
        # without replacement guarantees that no pair appears twice.
        pairs = torch.randperm(nb_colors * len(self.name_shapes))[:nb_items]
        col[:nb_items] = pairs % nb_colors
        shp[:nb_items] = pairs // nb_colors

        # Scatter the items over the grid with a shared cell permutation.
        placement = torch.randperm(nb_cells)
        col, shp = col[placement], shp[placement]

        return col.reshape(self.size, self.size), shp.reshape(self.size, self.size)

    def random_transformations(self, scene):
        """Apply a random sequence of flips / rotations to ``scene``.

        Returns:
            ``((col, shp), descriptions)`` where ``descriptions`` lists one
            "<chg> ..." string per applied transformation, in order. The
            sequence may be empty, in which case the scene is returned
            unchanged.
        """
        col, shp = scene

        descriptions = []
        nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item()
        transformations = torch.randint(
            len(self._TRANSFORMATIONS), (nb_transformations,)
        )

        for t in transformations.tolist():
            description, operation = self._TRANSFORMATIONS[t]
            # contiguous() so that later .view(-1) calls on the result work
            # even after a transpose.
            col = operation(col).contiguous()
            shp = operation(shp).contiguous()
            descriptions.append(description)

        return (col, shp), descriptions

    def _named_items(self, scene):
        """Return ``(row, column, "color shape")`` for each item, row-major."""
        col, shp = scene
        col, shp = col.tolist(), shp.tolist()
        return [
            (i, j, f"{self.name_colors[col[i][j]]} {self.name_shapes[shp[i][j]]}")
            for i in range(self.size)
            for j in range(self.size)
            if col[i][j] >= 0
        ]

    def print_scene(self, scene):
        """Print an ASCII rendering of ``scene`` to stdout.

        Items are drawn as the first letter of their color followed by
        their shape name; empty cells are drawn as grid intersections.
        """
        col, shp = scene
        col, shp = col.tolist(), shp.tolist()

        for i in range(self.size):
            cells = []
            for j in range(self.size):
                if col[i][j] >= 0:
                    cells.append(
                        f"{self.name_colors[col[i][j]][0]}{self.name_shapes[shp[i][j]]}"
                    )
                elif j == 0:
                    cells.append(" +")
                else:
                    cells.append("-+")
            print("--".join(cells))
            if i < self.size - 1:
                print(" |  " * (self.size - 1) + " |")

    def grid_positions(self, scene):
        """Return one "a <color> <shape> at <row> <col>" string per item."""
        return [f"a {name} at {i} {j}" for i, j, name in self._named_items(scene)]

    def all_properties(self, scene):
        """Return every statement that holds for ``scene``.

        For each item this lists its existence, the halves of the grid it
        lies in, and its relation (below / above / right of / left of /
        next to) to every item of the scene.
        """
        items = self._named_items(scene)
        half = self.size // 2

        properties = []

        for i1, j1, n1 in items:
            properties.append(f"there is a {n1}")
            properties.append(
                f"a {n1} is in the top half"
                if i1 < half
                else f"a {n1} is in the bottom half"
            )
            properties.append(
                f"a {n1} is in the left half"
                if j1 < half
                else f"a {n1} is in the right half"
            )
            # The item itself is part of `items` but triggers none of the
            # relations below, so it needs no special casing.
            for i2, j2, n2 in items:
                if i1 > i2:
                    properties.append(f"a {n1} is below a {n2}")
                if i1 < i2:
                    properties.append(f"a {n1} is above a {n2}")
                if j1 > j2:
                    properties.append(f"a {n1} is right of a {n2}")
                if j1 < j2:
                    properties.append(f"a {n1} is left of a {n2}")
                if abs(i1 - i2) + abs(j1 - j2) == 1:
                    properties.append(f"a {n1} is next to a {n2}")

        return properties

    def _false_properties(self, scene, true, nb_attempts=10):
        """Return at least ``nb_questions`` statements false for ``scene``.

        Candidates are the properties of a random reshuffling of the
        scene's cells that are not in ``true``. Returns ``None`` if none of
        the ``nb_attempts`` reshufflings yields enough of them.
        """
        true_set = set(true)

        for _ in range(nb_attempts):
            col, shp = scene
            col, shp = col.view(-1), shp.view(-1)
            p = torch.randperm(col.size(0))
            col, shp = col[p], shp[p]
            other_scene = (
                col.view(self.size, self.size),
                shp.view(self.size, self.size),
            )

            false = self.all_properties(other_scene)

            # We sometimes add properties from a totally different scene
            # to have negative "there is a xxx xxx" properties.
            if torch.rand(1).item() < 0.2:
                other_scene = self.generate_scene()
                false += self.all_properties(other_scene)

            # Order-preserving de-duplication: unlike a set difference, the
            # result does not depend on string hashing, so samples are
            # reproducible for a given torch seed.
            false = [q for q in dict.fromkeys(false) if q not in true_set]

            if len(false) >= self.nb_questions:
                return false

        return None

    @staticmethod
    def _pick(items, nb):
        """Return up to ``nb`` elements of ``items`` in random order."""
        return [items[k] for k in torch.randperm(len(items))[:nb].tolist()]

    def generate_scene_and_questions(self):
        """Generate one sample.

        Returns:
            ``(start_scene, scene, result)`` where ``scene`` is
            ``start_scene`` after the random transformations and ``result``
            is the token sequence: "<obj>" item positions of the start
            scene, the "<chg>" transformations, then ``nb_questions``
            "<prop> ... <ans> true|false" questions about the transformed
            scene.
        """
        while True:
            start_scene = self.generate_scene()
            scene, transformations = self.random_transformations(start_scene)
            true = self.all_properties(scene)
            if len(true) < self.nb_questions:
                continue

            # Draw a fresh scene if no reshuffling gave enough false
            # statements, instead of going on with too few of them.
            false = self._false_properties(scene, true)
            if false is not None:
                break

        true = self._pick(true, self.nb_questions)
        false = self._pick(false, self.nb_questions)
        true = ["<prop> " + q + " <ans> true" for q in true]
        false = ["<prop> " + q + " <ans> false" for q in false]

        questions = self._pick(true + false, self.nb_questions)

        result = " ".join(
            ["<obj> " + x for x in self.grid_positions(start_scene)]
            + transformations
            + questions
        )

        return start_scene, scene, result

    def generate_samples(self, nb, progress_bar=None):
        """Return a list of ``nb`` sample strings.

        Args:
            nb: number of samples to generate.
            progress_bar: optional callable wrapping the iteration range
                (e.g. a tqdm-like function) and returning an iterable.
        """
        r = range(nb)
        if progress_bar is not None:
            r = progress_bar(r)

        return [self.generate_scene_and_questions()[2] for _ in r]
349
350
351 ######################################################################
352
if __name__ == "__main__":
    import time

    factory = GridFactory()

    # Throughput benchmark, disabled by default:
    # start_time = time.perf_counter()
    # samples = factory.generate_samples(10000)
    # end_time = time.perf_counter()
    # print(f"{len(samples) / (end_time - start_time):.02f} samples per second")

    # Demo: draw one sample and show both scenes and the token sequence.
    initial_scene, transformed_scene, sequence = factory.generate_scene_and_questions()

    for banner, shown_scene in (
        ("-- Original scene -----------------------------", initial_scene),
        ("-- Transformed scene --------------------------", transformed_scene),
    ):
        print()
        print(banner)
        print()
        factory.print_scene(shown_scene)

    print()
    print("-- Sequence -----------------------------------")
    print()
    print(sequence)
374     print()
375     print(questions)
376
377 ######################################################################