268f4eed9ba64a57899a575828ba98cde55e1e13
[picoclvr.git] / grid.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math
9 import torch, torchvision
10 import torch.nn.functional as F
11
# Shape names: single letters, used verbatim in property strings and drawings.
name_shapes = list("ABCDEF")

# Color names; the first letters are pairwise distinct, so a drawing can
# abbreviate a color to its initial without ambiguity.
name_colors = "red yellow blue green white purple".split()
15
16 ######################################################################
17
18
class GridFactory:
    """Generate synthetic grid scenes and textual question sequences about them.

    A scene is a pair ``(col, shp)`` of ``size x size`` tensors. An occupied
    cell ``(i, j)`` holds an item whose color is ``name_colors[col[i, j]]`` and
    whose shape is ``name_shapes[shp[i, j]]``. Empty cells hold ``-1`` in both
    tensors. Row index ``i`` grows downward (row 0 is drawn first, at the top).
    """

    # (description token, transform) pairs, indexed by the integers drawn in
    # random_transformations. Each transform maps a 2D tensor to a new tensor.
    _TRANSFORMATIONS = (
        ("<chg> vertical flip", lambda x: x.flip(0)),
        ("<chg> horizontal flip", lambda x: x.flip(1)),
        # flip(0) then transpose == clockwise quarter turn
        ("<chg> rotate 90 degrees", lambda x: x.flip(0).t()),
        ("<chg> rotate 180 degrees", lambda x: x.flip(0).flip(1)),
        # flip(1) then transpose == counter-clockwise quarter turn
        ("<chg> rotate 270 degrees", lambda x: x.flip(1).t()),
    )

    def __init__(
        self,
        size=6,
        max_nb_items=4,
        max_nb_transformations=3,
        nb_questions=4,
    ):
        """
        Args:
            size: side length of the square grid. It must be even, presumably
                so that the top/bottom and left/right halves (split at
                ``size // 2``) have equal size.
            max_nb_items: maximum number of items per scene. At least 2 are
                always placed, so this should be >= 2.
            max_nb_transformations: maximum number of random flips/rotations
                applied to a scene (zero is also possible).
            nb_questions: number of questions per sample. Scenes are redrawn
                until at least this many true and false properties exist.
        """
        assert size % 2 == 0
        self.size = size
        self.max_nb_items = max_nb_items
        self.max_nb_transformations = max_nb_transformations
        self.nb_questions = nb_questions

    def generate_scene(self):
        """Draw a random scene with between 2 and ``max_nb_items`` items.

        Items have pairwise-distinct (color, shape) combinations and sit in
        distinct, uniformly random cells.

        Returns:
            ``(col, shp)``: two ``size x size`` tensors of color and shape
            indices, with ``-1`` marking empty cells.
        """
        nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2
        col = torch.full((self.size * self.size,), -1)
        shp = torch.full((self.size * self.size,), -1)
        # Distinct indices into the flattened (shape, color) table give
        # distinct (color, shape) combinations.
        a = torch.randperm(len(name_colors) * len(name_shapes))[:nb_items]
        col[:nb_items] = a % len(name_colors)
        shp[:nb_items] = a // len(name_colors)
        # Scatter items and empty cells uniformly over the grid.
        i = torch.randperm(self.size * self.size)
        col = col[i]
        shp = shp[i]
        return col.reshape(self.size, self.size), shp.reshape(self.size, self.size)

    def random_transformations(self, scene):
        """Apply a random sequence of flips/rotations to ``scene``.

        Draws between 0 and ``max_nb_transformations`` (inclusive) operations.
        Each one is chosen uniformly among ``_TRANSFORMATIONS`` and applied
        identically to both tensors.

        Returns:
            ``((col, shp), descriptions)``, where ``descriptions`` holds one
            ``"<chg> ..."`` string per applied operation, in order.
        """
        col, shp = scene

        descriptions = []
        nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item()
        transformations = torch.randint(len(self._TRANSFORMATIONS), (nb_transformations,))

        for t in transformations.tolist():
            description, transform = self._TRANSFORMATIONS[t]
            # .t() yields a non-contiguous view. Make the result contiguous so
            # that generate_scene_and_questions can .view() it.
            col, shp = transform(col).contiguous(), transform(shp).contiguous()
            descriptions.append(description)

        return (col, shp), descriptions

    def print_scene(self, scene):
        """Print an ASCII drawing of ``scene`` to stdout.

        An occupied cell shows the first letter of its color followed by its
        shape letter (e.g. ``rA``). An empty cell shows ``+``, preceded by a
        space in the first column and by ``-`` elsewhere. Cells are joined by
        ``--`` horizontally and ``|`` vertically.
        """
        col, shp = scene

        for i in range(self.size):
            for j in range(self.size):
                if col[i, j] >= 0:
                    print(f"{name_colors[col[i,j]][0]}{name_shapes[shp[i,j]]}", end="")
                elif j == 0:
                    print(" +", end="")
                else:
                    print("-+", end="")
                if j < self.size - 1:
                    print("--", end="")
                else:
                    print("")
            if i < self.size - 1:
                for j in range(self.size - 1):
                    print(" |  ", end="")
                print(" |")

    def _items(self, scene):
        """Return ``[(i, j, "color shape"), ...]`` for occupied cells, row-major."""
        col, shp = scene
        # Convert once to nested Python lists instead of indexing the tensors
        # element by element.
        col, shp = col.tolist(), shp.tolist()
        return [
            (i, j, f"{name_colors[col[i][j]]} {name_shapes[shp[i][j]]}")
            for i in range(self.size)
            for j in range(self.size)
            if col[i][j] >= 0
        ]

    def grid_positions(self, scene):
        """Return ``"a <color> <shape> at <i> <j>"`` for each item, row-major."""
        return [f"a {n} at {i} {j}" for i, j, n in self._items(scene)]

    def all_properties(self, scene):
        """Return every textual property that holds in ``scene``.

        Items are visited in row-major order. For each item the list gets its
        existence, the vertical and horizontal half it lies in, and its
        relative position (below / above / right of / left of / next to)
        with respect to every item. No relation holds between an item and
        itself.
        """
        items = self._items(scene)
        half = self.size // 2

        properties = []

        for i1, j1, n1 in items:
            properties.append(f"there is a {n1}")
            if i1 < half:
                properties.append(f"a {n1} is in the top half")
            else:
                properties.append(f"a {n1} is in the bottom half")
            if j1 < half:
                properties.append(f"a {n1} is in the left half")
            else:
                properties.append(f"a {n1} is in the right half")
            for i2, j2, n2 in items:
                if i1 > i2:
                    properties.append(f"a {n1} is below a {n2}")
                if i1 < i2:
                    properties.append(f"a {n1} is above a {n2}")
                if j1 > j2:
                    properties.append(f"a {n1} is right of a {n2}")
                if j1 < j2:
                    properties.append(f"a {n1} is left of a {n2}")
                # 4-neighbourhood adjacency (Manhattan distance 1)
                if abs(i1 - i2) + abs(j1 - j2) == 1:
                    properties.append(f"a {n1} is next to a {n2}")

        return properties

    def generate_scene_and_questions(self):
        """Draw a scene, transform it, and build a question sequence about it.

        Returns:
            ``(start_scene, scene, result)``. Here ``scene`` is
            ``start_scene`` after the random transformations. ``result`` is a
            single space-joined string made of three parts, in order:

            - the ``"<obj> ..."`` positions of ``start_scene``;
            - the ``"<chg> ..."`` transformation tokens;
            - ``nb_questions`` ``"<prop> ... <ans> true|false"`` entries
              about ``scene``.
        """
        while True:
            # Redraw until the transformed scene has enough true properties.
            while True:
                start_scene = self.generate_scene()
                scene, transformations = self.random_transformations(start_scene)
                true = self.all_properties(scene)
                if len(true) >= self.nb_questions:
                    break

            # Build negatives from random shuffles of the same cells, so that
            # they mention the same items in other arrangements.
            for _ in range(10):
                col, shp = scene
                p = torch.randperm(col.numel())
                other_scene = (
                    col.reshape(-1)[p].view(self.size, self.size),
                    shp.reshape(-1)[p].view(self.size, self.size),
                )

                false = self.all_properties(other_scene)

                # We sometimes add properties from a totally different scene
                # to have negative "there is a xxx xxx" properties
                if torch.rand(1).item() < 0.2:
                    false += self.all_properties(self.generate_scene())

                # sorted() rather than list(): set order for strings depends
                # on PYTHONHASHSEED, which would make the torch-driven
                # sampling below irreproducible across processes.
                false = sorted(set(false) - set(true))
                if len(false) >= self.nb_questions:
                    break
            else:
                # None of the shuffles gave enough false properties: start
                # over with a fresh scene.
                continue

            break

        true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions].tolist()]
        false = [false[k] for k in torch.randperm(len(false))[: self.nb_questions].tolist()]
        true = ["<prop> " + q + " <ans> true" for q in true]
        false = ["<prop> " + q + " <ans> false" for q in false]

        union = true + false
        questions = [
            union[k] for k in torch.randperm(len(union))[: self.nb_questions].tolist()
        ]

        result = " ".join(
            ["<obj> " + x for x in self.grid_positions(start_scene)]
            + transformations
            + questions
        )

        return start_scene, scene, result

    def generate_samples(self, nb, progress_bar=None):
        """Return a list of ``nb`` question-sequence strings.

        If ``progress_bar`` is given, it is called on the ``range`` being
        iterated and must return an iterable over the same values (e.g. a
        progress-reporting wrapper).
        """
        r = range(nb)
        if progress_bar is not None:
            r = progress_bar(r)

        return [self.generate_scene_and_questions()[2] for _ in r]
208
209
210 ######################################################################
211
if __name__ == "__main__":
    import time  # used only by the disabled throughput benchmark below

    grid_factory = GridFactory()

    # Throughput benchmark (disabled):
    # start_time = time.perf_counter()
    # samples = grid_factory.generate_samples(10000)
    # end_time = time.perf_counter()
    # print(f"{len(samples) / (end_time - start_time):.02f} samples per second")

    start_scene, scene, questions = grid_factory.generate_scene_and_questions()

    # Draw the scene before and after its random transformations.
    sections = (
        ("-- Original scene -----------------------------", start_scene),
        ("-- Transformed scene --------------------------", scene),
    )
    for header, shown in sections:
        print()
        print(header)
        print()
        grid_factory.print_scene(shown)

    # Then the full token sequence built for this sample.
    print()
    print("-- Sequence -----------------------------------")
    print()
    print(questions)
235
236 ######################################################################