Update.
[picoclvr.git] / grid.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math
9 import torch, torchvision
10 import torch.nn.functional as F
11
# Shape labels; the shape tensor of a scene stores indices into this list.
name_shapes = list("ABCDEF")

# Colour names; the colour tensor of a scene stores indices into this list.
name_colors = "red yellow blue green white purple".split()
15
16 ######################################################################
17
18
class GridFactory:
    """Generate random grid scenes and textual question/answer sequences about them.

    A scene is a pair ``(col, shp)`` of ``size x size`` integer tensors. A cell
    holds an item when ``col >= 0``; in that case ``col`` indexes ``name_colors``
    and ``shp`` indexes ``name_shapes``. Empty cells hold -1 in both tensors.
    """

    def __init__(
        self,
        size=4,
        max_nb_items=4,
        max_nb_transformations=3,
        nb_questions=4,
    ):
        self.size = size  # side length of the square grid
        self.max_nb_items = max_nb_items  # scenes contain 2..max_nb_items items
        self.max_nb_transformations = max_nb_transformations
        self.nb_questions = nb_questions  # number of "<prop>" questions per sample

    def generate_scene(self):
        """Return a random scene holding between 2 and ``max_nb_items`` items.

        Each (colour, shape) combination is used at most once, and the items are
        placed on distinct cells drawn uniformly at random.
        """
        # NOTE(review): assumes 2 <= max_nb_items <= size * size.
        nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2
        nb_cells = self.size * self.size
        col = torch.full((nb_cells,), -1)
        shp = torch.full((nb_cells,), -1)
        # Distinct (colour, shape) pairs, encoded as shape * nb_colors + colour.
        a = torch.randperm(len(name_colors) * len(name_shapes))[:nb_items]
        col[:nb_items] = a % len(name_colors)
        shp[:nb_items] = a // len(name_colors)
        # Scatter the items (and the empty cells) over the grid.
        i = torch.randperm(nb_cells)
        col = col[i]
        shp = shp[i]
        return col.reshape(self.size, self.size), shp.reshape(self.size, self.size)

    def random_transformations(self, scene):
        """Apply 0 to ``max_nb_transformations`` random flips/rotations.

        Returns ``(new_scene, descriptions)``, where ``descriptions`` lists one
        ``"<chg> ..."`` token per transformation, in the order of application.
        """
        col, shp = scene

        descriptions = []
        nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item()
        transformations = torch.randint(5, (nb_transformations,))

        for t in transformations.tolist():
            if t == 0:
                # Reverse the row order (top <-> bottom).
                col, shp = col.flip(0), shp.flip(0)
                descriptions.append("<chg> vertical flip")
            elif t == 1:
                # Reverse the column order (left <-> right).
                col, shp = col.flip(1), shp.flip(1)
                descriptions.append("<chg> horizontal flip")
            elif t == 2:
                # new[i, j] = old[n-1-j, i]: clockwise with row 0 drawn on top.
                col, shp = col.flip(0).t(), shp.flip(0).t()
                descriptions.append("<chg> rotate 90 degrees")
            elif t == 3:
                col, shp = col.flip(0).flip(1), shp.flip(0).flip(1)
                descriptions.append("<chg> rotate 180 degrees")
            elif t == 4:
                # new[i, j] = old[j, n-1-i]: the inverse of the 90-degree case.
                col, shp = col.flip(1).t(), shp.flip(1).t()
                descriptions.append("<chg> rotate 270 degrees")

            # .t() returns a strided view; materialise it so later .view() works.
            col, shp = col.contiguous(), shp.contiguous()

        return (col, shp), descriptions

    def _items(self, scene):
        """Return ``(row, column, "colour shape")`` for every item, in row-major order."""
        col, shp = scene
        # Plain Python lists: per-element tensor indexing is far slower.
        col, shp = col.tolist(), shp.tolist()
        return [
            (i, j, f"{name_colors[c]} {name_shapes[s]}")
            for i, (col_row, shp_row) in enumerate(zip(col, shp))
            for j, (c, s) in enumerate(zip(col_row, shp_row))
            if c >= 0
        ]

    def print_scene(self, scene):
        """Print an ASCII drawing of ``scene``, row 0 at the top.

        Items are shown as the colour's first letter followed by the shape;
        empty cells are drawn as ``+``.
        """
        col, shp = scene
        col, shp = col.tolist(), shp.tolist()

        for i in range(self.size):
            cells = []
            for j in range(self.size):
                if col[i][j] >= 0:
                    cells.append(name_colors[col[i][j]][0] + name_shapes[shp[i][j]])
                elif j == 0:
                    cells.append(" +")
                else:
                    cells.append("-+")
            print("--".join(cells))
            if i < self.size - 1:
                print(" |  " * (self.size - 1) + " |")

    def grid_positions(self, scene):
        """Return ``"a <colour> <shape> at <row> <column>"`` for every item, row-major."""
        return [f"a {n} at {i} {j}" for i, j, n in self._items(scene)]

    def all_properties(self, scene):
        """Return every property that holds in ``scene``.

        For each item (row-major), the list contains, in this order:
        - its existence;
        - which vertical half it is in (the middle row of an odd grid counts
          as bottom);
        - which horizontal half it is in;
        - its above/below/left/right relations to every other item.
        """
        items = self._items(scene)
        half = self.size // 2
        properties = []

        for i1, j1, n1 in items:
            properties.append(f"there is a {n1}")
            if i1 < half:
                properties.append(f"a {n1} is in the top half")
            else:
                properties.append(f"a {n1} is in the bottom half")
            if j1 < half:
                properties.append(f"a {n1} is in the left half")
            else:
                properties.append(f"a {n1} is in the right half")
            for i2, j2, n2 in items:
                if i1 > i2:
                    properties.append(f"a {n1} is below a {n2}")
                if i1 < i2:
                    properties.append(f"a {n1} is above a {n2}")
                if j1 > j2:
                    properties.append(f"a {n1} is right of a {n2}")
                if j1 < j2:
                    properties.append(f"a {n1} is left of a {n2}")

        return properties

    def _false_properties(self, scene, true, nb_attempts=10):
        """Return properties absent from ``true``, or None if none could be found.

        Candidates come from scenes obtained by randomly shuffling the cells of
        ``scene``. The first shuffle yielding at least ``nb_questions``
        properties not in ``true`` is used. None is returned when all
        ``nb_attempts`` shuffles fail.
        """
        true_set = set(true)
        col, shp = scene
        col, shp = col.reshape(-1), shp.reshape(-1)
        for _ in range(nb_attempts):
            p = torch.randperm(col.size(0))
            other_scene = (
                col[p].view(self.size, self.size),
                shp[p].view(self.size, self.size),
            )
            # dict.fromkeys de-duplicates while keeping a deterministic order.
            # A set of strings iterates in PYTHONHASHSEED-dependent order, which
            # would make samples irreproducible under a fixed torch seed.
            false = [
                q
                for q in dict.fromkeys(self.all_properties(other_scene))
                if q not in true_set
            ]
            if len(false) >= self.nb_questions:
                return false
        return None

    def generate_scene_and_questions(self):
        """Generate a random scene and its textual description.

        Returns ``(scene, text)``. ``scene`` is the transformed scene. ``text``
        is the space-separated concatenation of:
        - one ``"<obj> ..."`` token per item of ``scene``;
        - the ``"<chg> ..."`` transformation tokens;
        - ``nb_questions`` questions drawn from ``"<prop> ... <true>"`` and
          ``"<prop> ... <false>"`` candidates.
        """
        while True:
            # Draw scenes until one has enough true properties to ask about.
            while True:
                scene = self.generate_scene()
                true = self.all_properties(scene)
                if len(true) >= self.nb_questions:
                    break

            # NOTE(review): `true` is evaluated on the scene *before* the
            # transformations, while the "<obj>" tokens below list positions
            # *after* them. If the intended reading is "positions, then changes,
            # then properties of the result", either compute `true` after
            # random_transformations or emit the pre-transformation positions
            # — confirm.
            scene, transformations = self.random_transformations(scene)

            false = self._false_properties(scene, true)
            if false is not None:
                break
            # No shuffle produced enough false properties: start over.

        true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions]]
        false = [false[k] for k in torch.randperm(len(false))[: self.nb_questions]]
        true = ["<prop> " + q + " <true>" for q in true]
        false = ["<prop> " + q + " <false>" for q in false]

        # Mix true and false candidates and keep nb_questions of them.
        union = true + false
        questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]]

        result = " ".join(
            ["<obj> " + x for x in self.grid_positions(scene)]
            + transformations
            + questions
        )

        return scene, result

    def generate_samples(self, nb, progress_bar=None):
        """Return ``nb`` sample texts from generate_scene_and_questions.

        ``progress_bar``, if given, is called on the range being iterated
        (e.g. a tqdm-like wrapper) and must return an iterable.
        """
        r = range(nb)
        if progress_bar is not None:
            r = progress_bar(r)

        return [self.generate_scene_and_questions()[1] for _ in r]
201
202
203 ######################################################################
204
if __name__ == "__main__":
    import time

    # Benchmark sample generation throughput, then show one example.
    factory = GridFactory()

    t0 = time.perf_counter()
    samples = factory.generate_samples(10000)
    elapsed = time.perf_counter() - t0
    print(f"{len(samples) / elapsed:.02f} samples per second")

    example_scene, example_text = factory.generate_scene_and_questions()
    factory.print_scene(example_scene)
    print(example_text)
218
219 ######################################################################