433cfd5a5babea1c7605e5dfe63ca66245f3120a
[picoclvr.git] / grid.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math
9 import torch, torchvision
10 import torch.nn.functional as F
11
name_shapes = ["A", "B", "C", "D", "E", "F"]

name_colors = ["red", "yellow", "blue", "green", "white", "purple"]

######################################################################


class GridFactory:
    """Generate random grid scenes and textual true/false questions about them.

    A scene is a pair ``(col, shp)`` of 2-D integer tensors of identical
    shape.  A cell holds an index into ``name_colors`` (resp.
    ``name_shapes``), or ``-1`` when the cell is empty.
    """

    # Transformation index -> (description, operation on a 2-D tensor).
    # With row 0 at the top, ``flip(0).t()`` is a quarter turn clockwise and
    # ``flip(1).t()`` a quarter turn counter-clockwise (i.e. 270 degrees).
    _TRANSFORMATIONS = (
        ("vertical flip", lambda x: x.flip(0)),
        ("horizontal flip", lambda x: x.flip(1)),
        ("rotate 90 degrees", lambda x: x.flip(0).t()),
        ("rotate 180 degrees", lambda x: x.flip(0).flip(1)),
        ("rotate 270 degrees", lambda x: x.flip(1).t()),
    )

    # Number of reshuffled scenes tried when looking for false properties.
    _NB_FALSE_ATTEMPTS = 10

    def __init__(
        self,
        height=4,
        width=4,
        max_nb_items=4,
        max_nb_transformations=4,
        nb_questions=4,
    ):
        """Store the generation parameters.

        Args:
            height: number of rows of a generated scene.
            width: number of columns of a generated scene.
            max_nb_items: largest number of items placed in a scene; scenes
                always hold at least 2 items, so this must be >= 2.
            max_nb_transformations: largest number of transformations
                applied to a scene (0 is possible).
            nb_questions: number of questions appended to each sample.
        """
        self.height = height
        self.width = width
        self.max_nb_items = max_nb_items
        self.max_nb_transformations = max_nb_transformations
        self.nb_questions = nb_questions

    def generate_scene(self):
        """Return a random scene ``(col, shp)`` of shape ``(height, width)``.

        Between 2 and ``max_nb_items`` items are placed in distinct cells,
        each with a distinct (color, shape) combination.
        """
        nb_cells = self.height * self.width
        nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2
        col = torch.full((nb_cells,), -1)
        shp = torch.full((nb_cells,), -1)
        # Draw without replacement among all (color, shape) pairs so that an
        # item is identified unambiguously by its name.
        a = torch.randperm(len(name_colors) * len(name_shapes))[:nb_items]
        col[:nb_items] = a % len(name_colors)
        shp[:nb_items] = a // len(name_colors)
        # Same permutation for both tensors to keep color and shape paired.
        i = torch.randperm(nb_cells)
        return (
            col[i].reshape(self.height, self.width),
            shp[i].reshape(self.height, self.width),
        )

    def random_transformations(self, scene):
        """Apply between 0 and ``max_nb_transformations`` random transformations.

        Args:
            scene: the ``(col, shp)`` pair to transform; it is not modified.

        Returns:
            ``((col, shp), descriptions)`` where the tensors are contiguous
            and ``descriptions`` holds one ``"<chg> ..."`` string per
            transformation, in the order they were applied.  For a non-square
            scene the quarter-turn rotations swap the two dimensions.
        """
        col, shp = scene
        descriptions = []
        nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item()
        transformations = torch.randint(5, (nb_transformations,))

        for t in transformations.tolist():
            label, op = self._TRANSFORMATIONS[t]
            col, shp = op(col), op(shp)
            descriptions.append(f"<chg> {label}")

        return (col.contiguous(), shp.contiguous()), descriptions

    @staticmethod
    def _items(scene):
        """Return ``(row, column, "color shape")`` for every occupied cell, in row-major order."""
        col, shp = scene
        items = []
        # tolist() once instead of indexing the tensors cell by cell.
        for i, (col_row, shp_row) in enumerate(zip(col.tolist(), shp.tolist())):
            for j, (c, s) in enumerate(zip(col_row, shp_row)):
                if c >= 0:
                    items.append((i, j, f"{name_colors[c]} {name_shapes[s]}"))
        return items

    def print_scene(self, scene):
        """Print an ASCII drawing of ``scene`` to standard output.

        An occupied node shows the first letter of its color followed by its
        shape name; an empty node is drawn as ``+``.  The dimensions are read
        from the scene itself so that rotated non-square scenes print too.
        """
        col, shp = scene
        height, width = col.size()
        col, shp = col.tolist(), shp.tolist()

        for i in range(height):
            cells = []
            for j in range(width):
                if col[i][j] >= 0:
                    cells.append(f"{name_colors[col[i][j]][0]}{name_shapes[shp[i][j]]}")
                elif j == 0:
                    cells.append(" +")
                else:
                    cells.append("-+")
            print("--".join(cells))
            if i < height - 1:
                print(" |  " * (width - 1) + " |")

    def grid_positions(self, scene):
        """Return one ``"a <color> <shape> at <row> <column>"`` string per item, in row-major order."""
        return [f"a {n} at {i} {j}" for i, j, n in self._items(scene)]

    def all_properties(self, scene):
        """Return every statement that holds in ``scene``.

        For each item (row-major order): its existence, the vertical and
        horizontal half it lies in, then its position relative to every
        other item (below / above / right of / left of).  The halves are
        computed from the scene's own shape.
        """
        col, _ = scene
        height, width = col.size()
        items = self._items(scene)

        properties = []

        for i1, j1, n1 in items:
            properties.append(f"there is a {n1}")
            if i1 < height // 2:
                properties.append(f"a {n1} is in the top half")
            else:
                properties.append(f"a {n1} is in the bottom half")
            if j1 < width // 2:
                properties.append(f"a {n1} is in the left half")
            else:
                properties.append(f"a {n1} is in the right half")
            # Comparing an item with itself adds nothing, since every test
            # below is a strict inequality.
            for i2, j2, n2 in items:
                if i1 > i2:
                    properties.append(f"a {n1} is below a {n2}")
                if i1 < i2:
                    properties.append(f"a {n1} is above a {n2}")
                if j1 > j2:
                    properties.append(f"a {n1} is right of a {n2}")
                if j1 < j2:
                    properties.append(f"a {n1} is left of a {n2}")

        return properties

    def _sample_false_properties(self, scene, true):
        """Return at least ``nb_questions`` statements false in ``scene``, or ``None``.

        Candidates are the properties of random reshuffles of the cells of
        ``scene`` that are not in ``true``.  ``None`` is returned when no
        reshuffle among ``_NB_FALSE_ATTEMPTS`` yields enough of them.
        """
        col, shp = scene
        true_set = set(true)
        for _ in range(self._NB_FALSE_ATTEMPTS):
            p = torch.randperm(col.numel())
            other_scene = (
                col.reshape(-1)[p].view(col.size()),
                shp.reshape(-1)[p].view(shp.size()),
            )
            # Order-preserving filter: a set difference would make the
            # result depend on string hash randomisation.
            false = [
                q
                for q in dict.fromkeys(self.all_properties(other_scene))
                if q not in true_set
            ]
            if len(false) >= self.nb_questions:
                return false
        return None

    def generate_scene_and_questions(self):
        """Generate one sample.

        Returns:
            ``(scene, result)`` where ``scene`` is the scene obtained after
            the transformations and ``result`` is a single string made of the
            ``<obj>`` positions of the initial scene, the ``<chg>``
            transformations, then ``nb_questions`` ``<prop>`` statements
            tagged ``<true>`` or ``<false>`` with respect to ``scene``.
        """
        while True:
            start_scene = self.generate_scene()
            scene, transformations = self.random_transformations(start_scene)
            # The questions are about the final scene, so the true
            # properties must be computed after the transformations.
            true = self.all_properties(scene)
            if len(true) < self.nb_questions:
                continue
            false = self._sample_false_properties(scene, true)
            if false is not None:
                break
            # Otherwise start over with a fresh scene.

        true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions].tolist()]
        false = [false[k] for k in torch.randperm(len(false))[: self.nb_questions].tolist()]
        true = ["<prop> " + q + " <true>" for q in true]
        false = ["<prop> " + q + " <false>" for q in false]

        union = true + false
        questions = [
            union[k] for k in torch.randperm(len(union))[: self.nb_questions].tolist()
        ]

        result = " ".join(
            ["<obj> " + x for x in self.grid_positions(start_scene)]
            + transformations
            + questions
        )

        return scene, result

    def generate_samples(self, nb, progress_bar=None):
        """Return ``nb`` sample strings.

        Args:
            nb: number of samples to generate.
            progress_bar: optional callable wrapping the iteration range
                (it receives ``range(nb)`` and must return an iterable).
        """
        r = range(nb)
        if progress_bar is not None:
            r = progress_bar(r)

        return [self.generate_scene_and_questions()[1] for _ in r]
202
203
204 ######################################################################
205
if __name__ == "__main__":
    import time

    factory = GridFactory()

    # Throughput check: time the generation of a fixed batch of samples.
    t0 = time.perf_counter()
    batch = factory.generate_samples(10000)
    elapsed = time.perf_counter() - t0
    rate = len(batch) / elapsed
    print(f"{rate:.02f} samples per second")

    # Show one example: the drawn scene followed by its textual sample.
    example_scene, example_text = factory.generate_scene_and_questions()
    factory.print_scene(example_scene)
    print(example_text)
217     grid_factory.print_scene(scene)
218     print(questions)
219
220 ######################################################################