2135710e68faa57af119b1e9193c32223a1fc2f5
[picoclvr.git] / grid.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math
9 import torch, torchvision
10 import torch.nn.functional as F
11
12 ######################################################################
13
14
class GridFactory:
    """Generate small grid scenes and textual true/false questions about them.

    A scene is a pair ``(col, shp)`` of ``size x size`` integer tensors.
    ``col[i, j]`` and ``shp[i, j]`` are indices into ``name_colors`` and
    ``name_shapes`` for the item located at row ``i``, column ``j``, or
    ``-1`` when the cell is empty.
    """

    # Each entry pairs the description emitted in the generated sequence
    # with the operation applied to both the color and the shape grids.
    _TRANSFORMATIONS = (
        ("<chg> vertical flip", lambda x: x.flip(0)),
        ("<chg> horizontal flip", lambda x: x.flip(1)),
        ("<chg> rotate 90 degrees", lambda x: x.flip(0).t()),
        ("<chg> rotate 180 degrees", lambda x: x.flip(0).flip(1)),
        ("<chg> rotate 270 degrees", lambda x: x.flip(1).t()),
    )

    def __init__(
        self,
        size=6,
        max_nb_items=4,
        max_nb_transformations=3,
        nb_questions=4,
        nb_shapes=6,
        nb_colors=6,
    ):
        """Configure the generator.

        Args:
            size: side length of the square grid; must be even so that the
                grid splits into top/bottom and left/right halves.
            max_nb_items: upper bound on the number of items per scene
                (a scene always holds at least two items).
            max_nb_transformations: upper bound (inclusive) on the number
                of transformations applied to the initial scene.
            nb_questions: number of questions appended to each sequence.
            nb_shapes: accepted for interface compatibility; currently
                unused, the shape vocabulary is the fixed list below.
            nb_colors: accepted for interface compatibility; currently
                unused, the color vocabulary is the fixed list below.
        """
        assert size % 2 == 0, "size must be even"
        self.size = size
        self.max_nb_items = max_nb_items
        self.max_nb_transformations = max_nb_transformations
        self.nb_questions = nb_questions
        self.name_shapes = ["A", "B", "C", "D", "E", "F"]
        self.name_colors = ["red", "yellow", "blue", "green", "white", "purple"]

    def generate_scene(self):
        """Return a random scene ``(col, shp)`` with 2..max_nb_items items.

        Every item has a distinct (color, shape) combination and items
        are placed at distinct random cells.
        """
        nb_cells = self.size * self.size
        nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2
        col = torch.full((nb_cells,), -1)
        shp = torch.full((nb_cells,), -1)
        # Draw distinct (color, shape) pairs encoded as a single integer.
        a = torch.randperm(len(self.name_colors) * len(self.name_shapes))[:nb_items]
        col[:nb_items] = a % len(self.name_colors)
        shp[:nb_items] = a // len(self.name_colors)
        # The same permutation moves both attributes, keeping them aligned.
        i = torch.randperm(nb_cells)
        col = col[i]
        shp = shp[i]
        return col.reshape(self.size, self.size), shp.reshape(self.size, self.size)

    def random_transformations(self, scene):
        """Apply 0..max_nb_transformations random flips/rotations to ``scene``.

        Returns:
            ``((col, shp), descriptions)`` where ``descriptions`` lists one
            ``"<chg> ..."`` string per applied transformation, in order.
        """
        col, shp = scene

        nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item()
        transformations = torch.randint(
            len(self._TRANSFORMATIONS), (nb_transformations,)
        )

        descriptions = []
        for t in transformations.tolist():
            description, op = self._TRANSFORMATIONS[t]
            # contiguous() so that the result can later be flattened by view.
            col, shp = op(col).contiguous(), op(shp).contiguous()
            descriptions.append(description)

        return (col, shp), descriptions

    def _items(self, scene):
        """Return ``(row, column, "color shape")`` per item, in row-major order."""
        col, shp = scene
        # Convert once to Python lists: per-cell tensor indexing is slow.
        col, shp = col.tolist(), shp.tolist()
        items = []
        for i in range(self.size):
            for j in range(self.size):
                if col[i][j] >= 0:
                    name = f"{self.name_colors[col[i][j]]} {self.name_shapes[shp[i][j]]}"
                    items.append((i, j, name))
        return items

    def print_scene(self, scene):
        """Print an ASCII rendering of ``scene`` to standard output.

        An item is drawn as the first letter of its color followed by its
        shape name; an empty cell is drawn as a ``+`` grid intersection.
        """
        col, shp = scene
        col, shp = col.tolist(), shp.tolist()

        for i in range(self.size):
            cells = []
            for j in range(self.size):
                if col[i][j] >= 0:
                    cells.append(
                        f"{self.name_colors[col[i][j]][0]}{self.name_shapes[shp[i][j]]}"
                    )
                elif j == 0:
                    cells.append(" +")
                else:
                    cells.append("-+")
            print("--".join(cells))
            if i < self.size - 1:
                # Vertical connectors between two consecutive rows.
                print(" |  " * (self.size - 1) + " |")

    def grid_positions(self, scene):
        """Return one ``"a <color> <shape> at <row> <column>"`` string per item."""
        return [f"a {n} at {i} {j}" for i, j, n in self._items(scene)]

    def all_properties(self, scene):
        """Return every property that holds in ``scene``, as a list of strings.

        For each item this covers existence, the halves of the grid that
        contain it, and its relations (below / above / right of / left of /
        next to) with every other item.
        """
        items = self._items(scene)
        half = self.size // 2

        properties = []

        for i1, j1, n1 in items:
            properties.append(f"there is a {n1}")
            properties.append(
                f"a {n1} is in the top half"
                if i1 < half
                else f"a {n1} is in the bottom half"
            )
            properties.append(
                f"a {n1} is in the left half"
                if j1 < half
                else f"a {n1} is in the right half"
            )
            # Pairing an item with itself satisfies none of the conditions
            # below, so no explicit exclusion is needed.
            for i2, j2, n2 in items:
                if i1 > i2:
                    properties.append(f"a {n1} is below a {n2}")
                if i1 < i2:
                    properties.append(f"a {n1} is above a {n2}")
                if j1 > j2:
                    properties.append(f"a {n1} is right of a {n2}")
                if j1 < j2:
                    properties.append(f"a {n1} is left of a {n2}")
                if abs(i1 - i2) + abs(j1 - j2) == 1:
                    properties.append(f"a {n1} is next to a {n2}")

        return properties

    def _false_properties(self, scene, true, nb_attempts=10):
        """Try to build at least ``nb_questions`` properties false in ``scene``.

        Candidates are the properties of a randomly shuffled copy of
        ``scene``, minus those in ``true``. Returns the list of false
        properties, or ``None`` if every attempt yielded too few.
        """
        col, shp = scene
        col, shp = col.reshape(-1), shp.reshape(-1)
        true_set = set(true)

        for _ in range(nb_attempts):
            p = torch.randperm(col.size(0))
            other_scene = (
                col[p].view(self.size, self.size),
                shp[p].view(self.size, self.size),
            )

            candidates = self.all_properties(other_scene)

            # We sometimes add properties from a totally different
            # scene to have negative "there is a xxx xxx"
            # properties
            if torch.rand(1).item() < 0.2:
                candidates += self.all_properties(self.generate_scene())

            # Sorted so that the result does not depend on set iteration
            # order, which varies with string hash randomisation.
            false = sorted(set(candidates) - true_set)
            if len(false) >= self.nb_questions:
                return false

        return None

    def generate_scene_and_questions(self):
        """Generate a scene, transform it, and build the textual sequence.

        Returns:
            ``(start_scene, scene, result)`` where ``scene`` is
            ``start_scene`` after the random transformations and ``result``
            is the space-joined sequence of ``<obj>`` item positions in
            ``start_scene``, ``<chg>`` transformations, and ``<prop> ...
            <ans> true|false`` questions about ``scene``.
        """
        # Redraw the scene until it yields enough true properties and
        # enough false ones.
        while True:
            start_scene = self.generate_scene()
            scene, transformations = self.random_transformations(start_scene)
            true = self.all_properties(scene)
            if len(true) < self.nb_questions:
                continue
            false = self._false_properties(scene, true)
            if false is not None:
                break

        true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions].tolist()]
        false = [
            false[k] for k in torch.randperm(len(false))[: self.nb_questions].tolist()
        ]
        true = ["<prop> " + q + " <ans> true" for q in true]
        false = ["<prop> " + q + " <ans> false" for q in false]

        union = true + false
        questions = [
            union[k] for k in torch.randperm(len(union))[: self.nb_questions].tolist()
        ]

        result = " ".join(
            ["<obj> " + x for x in self.grid_positions(start_scene)]
            + transformations
            + questions
        )

        return start_scene, scene, result

    def generate_samples(self, nb, progress_bar=None):
        """Return ``nb`` generated sequences.

        Args:
            nb: number of sequences to generate.
            progress_bar: optional callable wrapping the iteration range
                (e.g. a tqdm-like function) and returning an iterable.
        """
        r = range(nb)
        if progress_bar is not None:
            r = progress_bar(r)

        return [self.generate_scene_and_questions()[2] for _ in r]
213
214
215 ######################################################################
216
if __name__ == "__main__":
    import time

    grid_factory = GridFactory()

    # Throughput benchmark, disabled by default:
    # start_time = time.perf_counter()
    # samples = grid_factory.generate_samples(10000)
    # end_time = time.perf_counter()
    # print(f"{len(samples) / (end_time - start_time):.02f} samples per second")

    # Demo: draw one sample and show the scene before and after the
    # transformations, followed by the generated token sequence.
    initial, transformed, sequence = grid_factory.generate_scene_and_questions()

    sections = (
        (
            "-- Original scene -----------------------------",
            lambda: grid_factory.print_scene(initial),
        ),
        (
            "-- Transformed scene --------------------------",
            lambda: grid_factory.print_scene(transformed),
        ),
        (
            "-- Sequence -----------------------------------",
            lambda: print(sequence),
        ),
    )

    for header, show in sections:
        print()
        print(header)
        print()
        show()
238     print()
239     print(questions)
240
241 ######################################################################