2512f98130faadd12ab35ca4f8efb8c36219734b
[mygpt.git] / picoclvr.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import torch, torchvision
9
10 colors = [
11     [ 255, 255, 255 ],
12     [ 255,   0,   0 ],
13     [   0, 255,   0 ],
14     [   0,   0, 255 ],
15     [ 255, 255,   0 ],
16     [   0,   0,   0 ],
17 ]
18
19 color_names = [
20     'white',
21     'red',
22     'green',
23     'blue',
24     'yellow',
25     'black',
26 ]
27
28 large_colors = [
29     [ 128, 0, 0 ], [ 139, 0, 0 ], [ 165, 42, 42 ], [ 178, 34, 34 ], [ 220, 20, 60 ],
30     [ 255, 0, 0 ], [ 255, 99, 71 ], [ 255, 127, 80 ], [ 205, 92, 92 ], [ 240, 128, 128 ],
31     [ 233, 150, 122 ], [ 250, 128, 114 ], [ 255, 160, 122 ], [ 255, 69, 0 ], [ 255, 140, 0 ],
32     [ 255, 165, 0 ], [ 255, 215, 0 ], [ 184, 134, 11 ], [ 218, 165, 32 ], [ 238, 232, 170 ],
33     [ 189, 183, 107 ], [ 240, 230, 140 ], [ 128, 128, 0 ], [ 255, 255, 0 ], [ 154, 205, 50 ],
34     [ 85, 107, 47 ], [ 107, 142, 35 ], [ 124, 252, 0 ], [ 127, 255, 0 ], [ 173, 255, 47 ],
35     [ 0, 100, 0 ], [ 0, 128, 0 ], [ 34, 139, 34 ], [ 0, 255, 0 ], [ 50, 205, 50 ],
36     [ 144, 238, 144 ], [ 152, 251, 152 ], [ 143, 188, 143 ], [ 0, 250, 154 ], [ 0, 255, 127 ],
37     [ 46, 139, 87 ], [ 102, 205, 170 ], [ 60, 179, 113 ], [ 32, 178, 170 ], [ 47, 79, 79 ],
38     [ 0, 128, 128 ], [ 0, 139, 139 ], [ 0, 255, 255 ], [ 0, 255, 255 ], [ 224, 255, 255 ],
39     [ 0, 206, 209 ], [ 64, 224, 208 ], [ 72, 209, 204 ], [ 175, 238, 238 ], [ 127, 255, 212 ],
40     [ 176, 224, 230 ], [ 95, 158, 160 ], [ 70, 130, 180 ], [ 100, 149, 237 ], [ 0, 191, 255 ],
41     [ 30, 144, 255 ], [ 173, 216, 230 ], [ 135, 206, 235 ], [ 135, 206, 250 ], [ 25, 25, 112 ],
42     [ 0, 0, 128 ], [ 0, 0, 139 ], [ 0, 0, 205 ], [ 0, 0, 255 ], [ 65, 105, 225 ],
43     [ 138, 43, 226 ], [ 75, 0, 130 ], [ 72, 61, 139 ], [ 106, 90, 205 ], [ 123, 104, 238 ],
44     [ 147, 112, 219 ], [ 139, 0, 139 ], [ 148, 0, 211 ], [ 153, 50, 204 ], [ 186, 85, 211 ],
45     [ 128, 0, 128 ], [ 216, 191, 216 ], [ 221, 160, 221 ], [ 238, 130, 238 ], [ 255, 0, 255 ],
46     [ 218, 112, 214 ], [ 199, 21, 133 ], [ 219, 112, 147 ], [ 255, 20, 147 ], [ 255, 105, 180 ],
47     [ 255, 182, 193 ], [ 255, 192, 203 ], [ 250, 235, 215 ], [ 245, 245, 220 ], [ 255, 228, 196 ],
48     [ 255, 235, 205 ], [ 245, 222, 179 ], [ 255, 248, 220 ], [ 255, 250, 205 ], [ 250, 250, 210 ],
49     [ 255, 255, 224 ], [ 139, 69, 19 ], [ 160, 82, 45 ], [ 210, 105, 30 ], [ 205, 133, 63 ],
50     [ 244, 164, 96 ], [ 222, 184, 135 ], [ 210, 180, 140 ], [ 188, 143, 143 ], [ 255, 228, 181 ],
51     [ 255, 222, 173 ], [ 255, 218, 185 ], [ 255, 228, 225 ], [ 255, 240, 245 ], [ 250, 240, 230 ],
52     [ 253, 245, 230 ], [ 255, 239, 213 ], [ 255, 245, 238 ], [ 245, 255, 250 ], [ 112, 128, 144 ],
53     [ 119, 136, 153 ], [ 176, 196, 222 ], [ 230, 230, 250 ], [ 255, 250, 240 ], [ 240, 248, 255 ],
54     [ 248, 248, 255 ], [ 240, 255, 240 ], [ 255, 255, 240 ], [ 240, 255, 255 ], [ 255, 250, 250 ],
55     [ 192, 192, 192 ], [ 220, 220, 220 ], [ 245, 245, 245 ],
56 ]
57
58 large_color_names = [
59     'maroon', 'dark_red', 'brown', 'firebrick', 'crimson',
60     'red', 'tomato', 'coral', 'indian_red', 'light_coral',
61     'dark_salmon', 'salmon', 'light_salmon', 'orange_red', 'dark_orange',
62     'orange', 'gold', 'dark_golden_rod', 'golden_rod', 'pale_golden_rod',
63     'dark_khaki', 'khaki', 'olive', 'yellow', 'yellow_green',
64     'dark_olive_green', 'olive_drab', 'lawn_green', 'chartreuse', 'green_yellow',
65     'dark_green', 'green', 'forest_green', 'lime', 'lime_green',
66     'light_green', 'pale_green', 'dark_sea_green', 'medium_spring_green', 'spring_green',
67     'sea_green', 'medium_aqua_marine', 'medium_sea_green', 'light_sea_green', 'dark_slate_gray',
68     'teal', 'dark_cyan', 'aqua', 'cyan', 'light_cyan',
69     'dark_turquoise', 'turquoise', 'medium_turquoise', 'pale_turquoise', 'aqua_marine',
70     'powder_blue', 'cadet_blue', 'steel_blue', 'corn_flower_blue', 'deep_sky_blue',
71     'dodger_blue', 'light_blue', 'sky_blue', 'light_sky_blue', 'midnight_blue',
72     'navy', 'dark_blue', 'medium_blue', 'blue', 'royal_blue',
73     'blue_violet', 'indigo', 'dark_slate_blue', 'slate_blue', 'medium_slate_blue',
74     'medium_purple', 'dark_magenta', 'dark_violet', 'dark_orchid', 'medium_orchid',
75     'purple', 'thistle', 'plum', 'violet', 'magenta',
76     'orchid', 'medium_violet_red', 'pale_violet_red', 'deep_pink', 'hot_pink',
77     'light_pink', 'pink', 'antique_white', 'beige', 'bisque',
78     'blanched_almond', 'wheat', 'corn_silk', 'lemon_chiffon', 'light_golden_rod_yellow',
79     'light_yellow', 'saddle_brown', 'sienna', 'chocolate', 'peru',
80     'sandy_brown', 'burly_wood', 'tan', 'rosy_brown', 'moccasin',
81     'navajo_white', 'peach_puff', 'misty_rose', 'lavender_blush', 'linen',
82     'old_lace', 'papaya_whip', 'sea_shell', 'mint_cream', 'slate_gray',
83     'light_slate_gray', 'light_steel_blue', 'lavender', 'floral_white', 'alice_blue',
84     'ghost_white', 'honeydew', 'ivory', 'azure', 'snow',
85     'silver', 'gainsboro', 'white_smoke',
86 ]
87
88 color_tokens = dict( [ (n, c) for n, c in zip(color_names, colors) ] )
89
90 ######################################################################
91
92 def generate(nb, height = 6, width = 8, max_nb_statements = 10):
93
94     descr = [ ]
95
96     for n in range(nb):
97
98         nb_squares = torch.randint(len(color_tokens) - 1, (1,)) + 1
99         square_position = torch.randperm(height * width)[:nb_squares]
100         square_c = torch.randperm(len(color_tokens) - 1)[:nb_squares] + 1
101         square_i = square_position.div(width, rounding_mode = 'floor')
102         square_j = square_position % width
103
104         img = [ 0 ] * height * width
105         for k in range(nb_squares): img[square_position[k]] = square_c[k]
106
107         # generates all the true relations
108
109         s = [ ]
110
111         for r, c in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]:
112             s += [ f'there is {c}' ]
113
114             if square_i[r] >= height - height//3: s += [ f'{c} bottom' ]
115             if square_i[r] < height//3: s += [ f'{c} top' ]
116             if square_j[r] >= width - width//3: s += [ f'{c} right' ]
117             if square_j[r] < width//3: s += [ f'{c} left' ]
118
119             for t, d in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]:
120                 if square_i[r] > square_i[t]: s += [ f'{c} below {d}' ]
121                 if square_i[r] < square_i[t]: s += [ f'{c} above {d}' ]
122                 if square_j[r] > square_j[t]: s += [ f'{c} right of {d}' ]
123                 if square_j[r] < square_j[t]: s += [ f'{c} left of {d}' ]
124
125         # pick at most max_nb_statements at random
126
127         nb_statements = torch.randint(max_nb_statements, (1,)) + 1
128         s = ' <sep> '.join([ s[k] for k in torch.randperm(len(s))[:nb_statements] ] )
129         s += ' <img> ' + ' '.join([ f'{color_names[n]}' for n in img ])
130
131         descr += [ s ]
132
133     return descr
134
135 ######################################################################
136
137 def descr2img(descr, height = 6, width = 8):
138
139     def token2color(t):
140         try:
141             return color_tokens[t]
142         except KeyError:
143             return [ 128, 128, 128 ]
144
145     def img_descr(x):
146         u = x.split('<img>', 1)
147         return u[1] if len(u) > 1 else ''
148
149     img = torch.full((len(descr), 3, height, width), 255)
150     d = [ img_descr(x) for x in descr ]
151     d = [ u.strip().split(' ')[:height * width] for u in d ]
152     d = [ u + [ '<unk>' ] * (height * width - len(u)) for u in d ]
153     d = [ [ token2color(t) for t in u ] for u in d ]
154     img = torch.tensor(d).permute(0, 2, 1)
155     img = img.reshape(img.size(0), 3, height, width)
156
157     return img
158
159 ######################################################################
160
161 if __name__ == '__main__':
162     descr = generate(5)
163     for d in descr:
164         print(d)
165         print()
166
167     img = descr2img(descr)
168     print(img.size())
169
170     torchvision.utils.save_image(img / 255.,
171                                  'picoclvr_example.png', nrow = 16, pad_value = 0.8)
172
173     import time
174
175     start_time = time.perf_counter()
176     descr = generate(10000)
177     end_time = time.perf_counter()
178     print(f'{len(descr) / (end_time - start_time):.02f} samples per second')
179
180 ######################################################################