Update.
[mygptrnn.git] / picoclvr.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math
9 import torch, torchvision
10 import torch.nn.functional as F
11
12 color_name2rgb = {
13     "white": [255, 255, 255],
14     "red": [255, 0, 0],
15     "green": [0, 128, 0],
16     "blue": [0, 0, 255],
17     "yellow": [255, 255, 0],
18     "black": [0, 0, 0],
19     "maroon": [128, 0, 0],
20     "dark_red": [139, 0, 0],
21     "brown": [165, 42, 42],
22     "firebrick": [178, 34, 34],
23     "crimson": [220, 20, 60],
24     "tomato": [255, 99, 71],
25     "coral": [255, 127, 80],
26     "indian_red": [205, 92, 92],
27     "light_coral": [240, 128, 128],
28     "dark_salmon": [233, 150, 122],
29     "salmon": [250, 128, 114],
30     "light_salmon": [255, 160, 122],
31     "orange_red": [255, 69, 0],
32     "dark_orange": [255, 140, 0],
33     "orange": [255, 165, 0],
34     "gold": [255, 215, 0],
35     "dark_golden_rod": [184, 134, 11],
36     "golden_rod": [218, 165, 32],
37     "pale_golden_rod": [238, 232, 170],
38     "dark_khaki": [189, 183, 107],
39     "khaki": [240, 230, 140],
40     "olive": [128, 128, 0],
41     "yellow_green": [154, 205, 50],
42     "dark_olive_green": [85, 107, 47],
43     "olive_drab": [107, 142, 35],
44     "lawn_green": [124, 252, 0],
45     "chartreuse": [127, 255, 0],
46     "green_yellow": [173, 255, 47],
47     "dark_green": [0, 100, 0],
48     "forest_green": [34, 139, 34],
49     "lime": [0, 255, 0],
50     "lime_green": [50, 205, 50],
51     "light_green": [144, 238, 144],
52     "pale_green": [152, 251, 152],
53     "dark_sea_green": [143, 188, 143],
54     "medium_spring_green": [0, 250, 154],
55     "spring_green": [0, 255, 127],
56     "sea_green": [46, 139, 87],
57     "medium_aqua_marine": [102, 205, 170],
58     "medium_sea_green": [60, 179, 113],
59     "light_sea_green": [32, 178, 170],
60     "dark_slate_gray": [47, 79, 79],
61     "teal": [0, 128, 128],
62     "dark_cyan": [0, 139, 139],
63     "aqua": [0, 255, 255],
64     "cyan": [0, 255, 255],
65     "light_cyan": [224, 255, 255],
66     "dark_turquoise": [0, 206, 209],
67     "turquoise": [64, 224, 208],
68     "medium_turquoise": [72, 209, 204],
69     "pale_turquoise": [175, 238, 238],
70     "aqua_marine": [127, 255, 212],
71     "powder_blue": [176, 224, 230],
72     "cadet_blue": [95, 158, 160],
73     "steel_blue": [70, 130, 180],
74     "corn_flower_blue": [100, 149, 237],
75     "deep_sky_blue": [0, 191, 255],
76     "dodger_blue": [30, 144, 255],
77     "light_blue": [173, 216, 230],
78     "sky_blue": [135, 206, 235],
79     "light_sky_blue": [135, 206, 250],
80     "midnight_blue": [25, 25, 112],
81     "navy": [0, 0, 128],
82     "dark_blue": [0, 0, 139],
83     "medium_blue": [0, 0, 205],
84     "royal_blue": [65, 105, 225],
85     "blue_violet": [138, 43, 226],
86     "indigo": [75, 0, 130],
87     "dark_slate_blue": [72, 61, 139],
88     "slate_blue": [106, 90, 205],
89     "medium_slate_blue": [123, 104, 238],
90     "medium_purple": [147, 112, 219],
91     "dark_magenta": [139, 0, 139],
92     "dark_violet": [148, 0, 211],
93     "dark_orchid": [153, 50, 204],
94     "medium_orchid": [186, 85, 211],
95     "purple": [128, 0, 128],
96     "thistle": [216, 191, 216],
97     "plum": [221, 160, 221],
98     "violet": [238, 130, 238],
99     "magenta": [255, 0, 255],
100     "orchid": [218, 112, 214],
101     "medium_violet_red": [199, 21, 133],
102     "pale_violet_red": [219, 112, 147],
103     "deep_pink": [255, 20, 147],
104     "hot_pink": [255, 105, 180],
105     "light_pink": [255, 182, 193],
106     "pink": [255, 192, 203],
107     "antique_white": [250, 235, 215],
108     "beige": [245, 245, 220],
109     "bisque": [255, 228, 196],
110     "blanched_almond": [255, 235, 205],
111     "wheat": [245, 222, 179],
112     "corn_silk": [255, 248, 220],
113     "lemon_chiffon": [255, 250, 205],
114     "light_golden_rod_yellow": [250, 250, 210],
115     "light_yellow": [255, 255, 224],
116     "saddle_brown": [139, 69, 19],
117     "sienna": [160, 82, 45],
118     "chocolate": [210, 105, 30],
119     "peru": [205, 133, 63],
120     "sandy_brown": [244, 164, 96],
121     "burly_wood": [222, 184, 135],
122     "tan": [210, 180, 140],
123     "rosy_brown": [188, 143, 143],
124     "moccasin": [255, 228, 181],
125     "navajo_white": [255, 222, 173],
126     "peach_puff": [255, 218, 185],
127     "misty_rose": [255, 228, 225],
128     "lavender_blush": [255, 240, 245],
129     "linen": [250, 240, 230],
130     "old_lace": [253, 245, 230],
131     "papaya_whip": [255, 239, 213],
132     "sea_shell": [255, 245, 238],
133     "mint_cream": [245, 255, 250],
134     "slate_gray": [112, 128, 144],
135     "light_slate_gray": [119, 136, 153],
136     "light_steel_blue": [176, 196, 222],
137     "lavender": [230, 230, 250],
138     "floral_white": [255, 250, 240],
139     "alice_blue": [240, 248, 255],
140     "ghost_white": [248, 248, 255],
141     "honeydew": [240, 255, 240],
142     "ivory": [255, 255, 240],
143     "azure": [240, 255, 255],
144     "snow": [255, 250, 250],
145     "silver": [192, 192, 192],
146     "gainsboro": [220, 220, 220],
147     "white_smoke": [245, 245, 245],
148 }
149
150 color_name2id = dict([(n, k) for k, n in enumerate(color_name2rgb.keys())])
151 color_id2name = dict([(k, n) for k, n in enumerate(color_name2rgb.keys())])
152
153 ######################################################################
154
155
156 def all_properties(height, width, nb_squares, square_i, square_j, square_c):
157     s = []
158
159     for r, c_r in [(k, color_id2name[square_c[k].item()]) for k in range(nb_squares)]:
160         s += [f"there is {c_r}"]
161
162         if square_i[r] >= height - height // 3:
163             s += [f"{c_r} bottom"]
164         if square_i[r] < height // 3:
165             s += [f"{c_r} top"]
166         if square_j[r] >= width - width // 3:
167             s += [f"{c_r} right"]
168         if square_j[r] < width // 3:
169             s += [f"{c_r} left"]
170
171         for t, c_t in [
172             (k, color_id2name[square_c[k].item()]) for k in range(nb_squares)
173         ]:
174             if square_i[r] > square_i[t]:
175                 s += [f"{c_r} below {c_t}"]
176             if square_i[r] < square_i[t]:
177                 s += [f"{c_r} above {c_t}"]
178             if square_j[r] > square_j[t]:
179                 s += [f"{c_r} right of {c_t}"]
180             if square_j[r] < square_j[t]:
181                 s += [f"{c_r} left of {c_t}"]
182
183     return s
184
185
186 ######################################################################
187
188 # Generates sequences
189
190
191 def generate(
192     nb,
193     height,
194     width,
195     max_nb_squares=5,
196     max_nb_properties=10,
197     nb_colors=5,
198     pruner=None,
199 ):
200     assert nb_colors >= max_nb_squares and nb_colors <= len(color_name2rgb) - 1
201
202     descr = []
203
204     for n in range(nb):
205         # we want uniform over the combinations of 1 to max_nb_squares
206         # pixels of nb_colors
207         logits = math.log(nb_colors) * torch.arange(1, max_nb_squares + 1).float()
208         dist = torch.distributions.categorical.Categorical(logits=logits)
209         nb_squares = dist.sample((1,)) + 1
210         # nb_squares = torch.randint(max_nb_squares, (1,)) + 1
211         square_position = torch.randperm(height * width)[:nb_squares]
212
213         # color 0 is white and reserved for the background
214         square_c = torch.randperm(nb_colors)[:nb_squares] + 1
215         square_i = square_position.div(width, rounding_mode="floor")
216         square_j = square_position % width
217
218         img = torch.zeros(height * width, dtype=torch.int64)
219         for k in range(nb_squares):
220             img[square_position[k]] = square_c[k]
221
222         # generates all the true properties
223
224         s = all_properties(height, width, nb_squares, square_i, square_j, square_c)
225
226         if pruner is not None:
227             s = list(filter(pruner, s))
228
229         # picks at most max_nb_properties at random
230
231         nb_properties = torch.randint(max_nb_properties, (1,)) + 1
232         s = (
233             " <sep> ".join([s[k] for k in torch.randperm(len(s))[:nb_properties]])
234             + " <img> "
235             + " ".join([f"{color_id2name[n.item()]}" for n in img])
236         )
237
238         descr += [s]
239
240     return descr
241
242
243 ######################################################################
244
245 # Extracts the image after <img> in descr as a 1x3xHxW tensor
246
247
248 def descr2img(descr, height, width):
249     result = []
250
251     def token2color(t):
252         try:
253             return color_name2rgb[t]
254         except KeyError:
255             return [128, 128, 128]
256
257     for d in descr:
258         d = d.split("<img>")[1]
259         d = d.strip().split(" ")[: height * width]
260         d = d + ["<unk>"] * (height * width - len(d))
261         d = [token2color(t) for t in d]
262         img = torch.tensor(d).permute(1, 0).reshape(1, 3, height, width)
263         result.append(img)
264
265     return torch.cat(result, 0)
266
267
268 ######################################################################
269
270 # Returns all the properties of the image after <img> in descr
271
272
273 def descr2properties(descr, height, width):
274     if type(descr) == list:
275         return [descr2properties(d, height, width) for d in descr]
276
277     d = descr.split("<img>")
278     img_tokens = d[-1] if len(d) > 1 else ""
279     img_tokens = img_tokens.strip().split(" ")[: height * width]
280     if len(img_tokens) != height * width:
281         return []
282
283     seen = {}
284     for k, x in enumerate(img_tokens):
285         if x != color_id2name[0]:
286             if x in color_name2rgb:
287                 if x in seen:
288                     return []
289             else:
290                 return []
291             seen[x] = (color_name2id[x], k // width, k % width)
292
293     square_infos = tuple(zip(*seen.values()))
294
295     if square_infos:
296         square_c = torch.tensor(square_infos[0])
297         square_i = torch.tensor(square_infos[1])
298         square_j = torch.tensor(square_infos[2])
299     else:
300         square_c = torch.tensor([])
301         square_i = torch.tensor([])
302         square_j = torch.tensor([])
303
304     s = all_properties(height, width, len(seen), square_i, square_j, square_c)
305
306     return s
307
308
309 ######################################################################
310
311 # Returns a triplet composed of (1) the total number of properties
312 # before <img> in descr, (2) the total number of properties the image
313 # after <img> verifies, and (3) the number of properties in (1) not in
314 # (2)
315
316
317 def nb_properties(descr, height, width, pruner=None):
318     if type(descr) == list:
319         return [nb_properties(d, height, width, pruner) for d in descr]
320
321     d = descr.split("<img>", 1)
322     if len(d) == 0:
323         return 0
324     d = d[0].strip().split("<sep>")
325     d = [x.strip() for x in d]
326
327     all_properties = set(descr2properties(descr, height, width))
328
329     if pruner is None:
330         requested_properties = set(d)
331     else:
332         requested_properties = set(filter(pruner, d))
333
334     missing_properties = requested_properties - all_properties
335
336     return (len(requested_properties), len(all_properties), len(missing_properties))
337
338
339 ######################################################################
340
341 if __name__ == "__main__":
342     for n in range(16):
343         descr = generate(nb=1, height=12, width=16)
344
345         print(nb_properties(descr, height=12, width=16))
346
347         with open(f"picoclvr_example_{n:02d}.txt", "w") as f:
348             for d in descr:
349                 f.write(f"{d}\n\n")
350
351         img = descr2img(descr, height=12, width=16)
352         if img.size(0) == 1:
353             img = F.pad(img, (1, 1, 1, 1), value=64)
354
355         torchvision.utils.save_image(
356             img / 255.0,
357             f"picoclvr_example_{n:02d}.png",
358             padding=1,
359             nrow=4,
360             pad_value=0.8,
361         )
362
363     import time
364
365     start_time = time.perf_counter()
366     descr = generate(nb=1000, height=12, width=16)
367     end_time = time.perf_counter()
368     print(f"{len(descr) / (end_time - start_time):.02f} samples per second")
369
370 ######################################################################