8da4fe8427c6e79f9e0a291d074282d8e06e3dc2
[mygpt.git] / picoclvr.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import torch, torchvision
9
10 colors = [
11     [ 255, 255, 255 ],
12     [ 255, 0, 0 ],
13     [ 0, 128, 0 ],
14     [ 0, 0, 255 ],
15     [ 255, 255, 0 ],
16     [ 0, 0, 0 ],
17     [ 128, 0, 0 ],
18     [ 139, 0, 0 ],
19     [ 165, 42, 42 ],
20     [ 178, 34, 34 ],
21     [ 220, 20, 60 ],
22     [ 255, 99, 71 ],
23     [ 255, 127, 80 ],
24     [ 205, 92, 92 ],
25     [ 240, 128, 128 ],
26     [ 233, 150, 122 ],
27     [ 250, 128, 114 ],
28     [ 255, 160, 122 ],
29     [ 255, 69, 0 ],
30     [ 255, 140, 0 ],
31     [ 255, 165, 0 ],
32     [ 255, 215, 0 ],
33     [ 184, 134, 11 ],
34     [ 218, 165, 32 ],
35     [ 238, 232, 170 ],
36     [ 189, 183, 107 ],
37     [ 240, 230, 140 ],
38     [ 128, 128, 0 ],
39     [ 154, 205, 50 ],
40     [ 85, 107, 47 ],
41     [ 107, 142, 35 ],
42     [ 124, 252, 0 ],
43     [ 127, 255, 0 ],
44     [ 173, 255, 47 ],
45     [ 0, 100, 0 ],
46     [ 34, 139, 34 ],
47     [ 0, 255, 0 ],
48     [ 50, 205, 50 ],
49     [ 144, 238, 144 ],
50     [ 152, 251, 152 ],
51     [ 143, 188, 143 ],
52     [ 0, 250, 154 ],
53     [ 0, 255, 127 ],
54     [ 46, 139, 87 ],
55     [ 102, 205, 170 ],
56     [ 60, 179, 113 ],
57     [ 32, 178, 170 ],
58     [ 47, 79, 79 ],
59     [ 0, 128, 128 ],
60     [ 0, 139, 139 ],
61     [ 0, 255, 255 ],
62     [ 0, 255, 255 ],
63     [ 224, 255, 255 ],
64     [ 0, 206, 209 ],
65     [ 64, 224, 208 ],
66     [ 72, 209, 204 ],
67     [ 175, 238, 238 ],
68     [ 127, 255, 212 ],
69     [ 176, 224, 230 ],
70     [ 95, 158, 160 ],
71     [ 70, 130, 180 ],
72     [ 100, 149, 237 ],
73     [ 0, 191, 255 ],
74     [ 30, 144, 255 ],
75     [ 173, 216, 230 ],
76     [ 135, 206, 235 ],
77     [ 135, 206, 250 ],
78     [ 25, 25, 112 ],
79     [ 0, 0, 128 ],
80     [ 0, 0, 139 ],
81     [ 0, 0, 205 ],
82     [ 65, 105, 225 ],
83     [ 138, 43, 226 ],
84     [ 75, 0, 130 ],
85     [ 72, 61, 139 ],
86     [ 106, 90, 205 ],
87     [ 123, 104, 238 ],
88     [ 147, 112, 219 ],
89     [ 139, 0, 139 ],
90     [ 148, 0, 211 ],
91     [ 153, 50, 204 ],
92     [ 186, 85, 211 ],
93     [ 128, 0, 128 ],
94     [ 216, 191, 216 ],
95     [ 221, 160, 221 ],
96     [ 238, 130, 238 ],
97     [ 255, 0, 255 ],
98     [ 218, 112, 214 ],
99     [ 199, 21, 133 ],
100     [ 219, 112, 147 ],
101     [ 255, 20, 147 ],
102     [ 255, 105, 180 ],
103     [ 255, 182, 193 ],
104     [ 255, 192, 203 ],
105     [ 250, 235, 215 ],
106     [ 245, 245, 220 ],
107     [ 255, 228, 196 ],
108     [ 255, 235, 205 ],
109     [ 245, 222, 179 ],
110     [ 255, 248, 220 ],
111     [ 255, 250, 205 ],
112     [ 250, 250, 210 ],
113     [ 255, 255, 224 ],
114     [ 139, 69, 19 ],
115     [ 160, 82, 45 ],
116     [ 210, 105, 30 ],
117     [ 205, 133, 63 ],
118     [ 244, 164, 96 ],
119     [ 222, 184, 135 ],
120     [ 210, 180, 140 ],
121     [ 188, 143, 143 ],
122     [ 255, 228, 181 ],
123     [ 255, 222, 173 ],
124     [ 255, 218, 185 ],
125     [ 255, 228, 225 ],
126     [ 255, 240, 245 ],
127     [ 250, 240, 230 ],
128     [ 253, 245, 230 ],
129     [ 255, 239, 213 ],
130     [ 255, 245, 238 ],
131     [ 245, 255, 250 ],
132     [ 112, 128, 144 ],
133     [ 119, 136, 153 ],
134     [ 176, 196, 222 ],
135     [ 230, 230, 250 ],
136     [ 255, 250, 240 ],
137     [ 240, 248, 255 ],
138     [ 248, 248, 255 ],
139     [ 240, 255, 240 ],
140     [ 255, 255, 240 ],
141     [ 240, 255, 255 ],
142     [ 255, 250, 250 ],
143     [ 192, 192, 192 ],
144     [ 220, 220, 220 ],
145     [ 245, 245, 245 ],
146 ]
147
148 color_names = [
149     'white',
150     'red',
151     'green',
152     'blue',
153     'yellow',
154     'black',
155     'maroon',
156     'dark_red',
157     'brown',
158     'firebrick',
159     'crimson',
160     'tomato',
161     'coral',
162     'indian_red',
163     'light_coral',
164     'dark_salmon',
165     'salmon',
166     'light_salmon',
167     'orange_red',
168     'dark_orange',
169     'orange',
170     'gold',
171     'dark_golden_rod',
172     'golden_rod',
173     'pale_golden_rod',
174     'dark_khaki',
175     'khaki',
176     'olive',
177     'yellow_green',
178     'dark_olive_green',
179     'olive_drab',
180     'lawn_green',
181     'chartreuse',
182     'green_yellow',
183     'dark_green',
184     'forest_green',
185     'lime',
186     'lime_green',
187     'light_green',
188     'pale_green',
189     'dark_sea_green',
190     'medium_spring_green',
191     'spring_green',
192     'sea_green',
193     'medium_aqua_marine',
194     'medium_sea_green',
195     'light_sea_green',
196     'dark_slate_gray',
197     'teal',
198     'dark_cyan',
199     'aqua',
200     'cyan',
201     'light_cyan',
202     'dark_turquoise',
203     'turquoise',
204     'medium_turquoise',
205     'pale_turquoise',
206     'aqua_marine',
207     'powder_blue',
208     'cadet_blue',
209     'steel_blue',
210     'corn_flower_blue',
211     'deep_sky_blue',
212     'dodger_blue',
213     'light_blue',
214     'sky_blue',
215     'light_sky_blue',
216     'midnight_blue',
217     'navy',
218     'dark_blue',
219     'medium_blue',
220     'royal_blue',
221     'blue_violet',
222     'indigo',
223     'dark_slate_blue',
224     'slate_blue',
225     'medium_slate_blue',
226     'medium_purple',
227     'dark_magenta',
228     'dark_violet',
229     'dark_orchid',
230     'medium_orchid',
231     'purple',
232     'thistle',
233     'plum',
234     'violet',
235     'magenta',
236     'orchid',
237     'medium_violet_red',
238     'pale_violet_red',
239     'deep_pink',
240     'hot_pink',
241     'light_pink',
242     'pink',
243     'antique_white',
244     'beige',
245     'bisque',
246     'blanched_almond',
247     'wheat',
248     'corn_silk',
249     'lemon_chiffon',
250     'light_golden_rod_yellow',
251     'light_yellow',
252     'saddle_brown',
253     'sienna',
254     'chocolate',
255     'peru',
256     'sandy_brown',
257     'burly_wood',
258     'tan',
259     'rosy_brown',
260     'moccasin',
261     'navajo_white',
262     'peach_puff',
263     'misty_rose',
264     'lavender_blush',
265     'linen',
266     'old_lace',
267     'papaya_whip',
268     'sea_shell',
269     'mint_cream',
270     'slate_gray',
271     'light_slate_gray',
272     'light_steel_blue',
273     'lavender',
274     'floral_white',
275     'alice_blue',
276     'ghost_white',
277     'honeydew',
278     'ivory',
279     'azure',
280     'snow',
281     'silver',
282     'gainsboro',
283     'white_smoke',
284 ]
285
286 color_tokens = dict( [ (n, c) for n, c in zip(color_names, colors) ] )
287
288 ######################################################################
289
290 def generate(nb, height = 6, width = 8, max_nb_squares = 5, max_nb_statements = 10, many_colors = False):
291
292     nb_colors =  len(color_tokens) - 1 if many_colors else max_nb_squares
293
294     descr = [ ]
295
296     for n in range(nb):
297
298         nb_squares = torch.randint(max_nb_squares, (1,)) + 1
299         square_position = torch.randperm(height * width)[:nb_squares]
300         square_c = torch.randperm(nb_colors)[:nb_squares] + 1
301         square_i = square_position.div(width, rounding_mode = 'floor')
302         square_j = square_position % width
303
304         img = [ 0 ] * height * width
305         for k in range(nb_squares): img[square_position[k]] = square_c[k]
306
307         # generates all the true relations
308
309         s = [ ]
310
311         for r, c in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]:
312             s += [ f'there is {c}' ]
313
314             if square_i[r] >= height - height//3: s += [ f'{c} bottom' ]
315             if square_i[r] < height//3: s += [ f'{c} top' ]
316             if square_j[r] >= width - width//3: s += [ f'{c} right' ]
317             if square_j[r] < width//3: s += [ f'{c} left' ]
318
319             for t, d in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]:
320                 if square_i[r] > square_i[t]: s += [ f'{c} below {d}' ]
321                 if square_i[r] < square_i[t]: s += [ f'{c} above {d}' ]
322                 if square_j[r] > square_j[t]: s += [ f'{c} right of {d}' ]
323                 if square_j[r] < square_j[t]: s += [ f'{c} left of {d}' ]
324
325         # pick at most max_nb_statements at random
326
327         nb_statements = torch.randint(max_nb_statements, (1,)) + 1
328         s = ' <sep> '.join([ s[k] for k in torch.randperm(len(s))[:nb_statements] ] )
329         s += ' <img> ' + ' '.join([ f'{color_names[n]}' for n in img ])
330
331         descr += [ s ]
332
333     return descr
334
335 ######################################################################
336
337 def descr2img(descr, height = 6, width = 8):
338
339     def token2color(t):
340         try:
341             return color_tokens[t]
342         except KeyError:
343             return [ 128, 128, 128 ]
344
345     def img_descr(x):
346         u = x.split('<img>', 1)
347         return u[1] if len(u) > 1 else ''
348
349     img = torch.full((len(descr), 3, height, width), 255)
350     d = [ img_descr(x) for x in descr ]
351     d = [ u.strip().split(' ')[:height * width] for u in d ]
352     d = [ u + [ '<unk>' ] * (height * width - len(u)) for u in d ]
353     d = [ [ token2color(t) for t in u ] for u in d ]
354     img = torch.tensor(d).permute(0, 2, 1)
355     img = img.reshape(img.size(0), 3, height, width)
356
357     return img
358
359 ######################################################################
360
361 if __name__ == '__main__':
362     descr = generate(nb = 5)
363     for d in descr:
364         print(d)
365         print()
366
367     img = descr2img(descr)
368     print(img.size())
369
370     torchvision.utils.save_image(img / 255.,
371                                  'picoclvr_example.png', nrow = 16, pad_value = 0.8)
372
373     import time
374
375     start_time = time.perf_counter()
376     descr = generate(10000)
377     end_time = time.perf_counter()
378     print(f'{len(descr) / (end_time - start_time):.02f} samples per second')
379
380 ######################################################################