Update.
[mygpt.git] / picoclvr.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import torch, torchvision
9
10 colors = [
11     [ 255, 255, 255 ],
12     [ 255,   0,   0 ],
13     [   0, 255,   0 ],
14     [   0,   0, 255 ],
15     [ 255, 255,   0 ],
16     [   0,   0,   0 ],
17 ]
18
19 color_names = [
20     'white',
21     'red',
22     'green',
23     'blue',
24     'yellow',
25     'black',
26 ]
27
28 color_tokens = dict( [ (n, c) for n, c in zip(color_names, colors) ] )
29
30 ######################################################################
31
32 def generate(nb, height = 6, width = 8, max_nb_statements = 10):
33
34     descr = [ ]
35
36     for n in range(nb):
37
38         nb_squares = torch.randint(len(color_tokens) - 1, (1,)) + 1
39         square_position = torch.randperm(height * width)[:nb_squares]
40         square_c = torch.randperm(len(color_tokens) - 1)[:nb_squares] + 1
41         square_i = square_position.div(width, rounding_mode = 'floor')
42         square_j = square_position % width
43
44         img = [ 0 ] * height * width
45         for k in range(nb_squares): img[square_position[k]] = square_c[k]
46
47         # generates all the true relations
48
49         s = [ ]
50
51         for r, c in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]:
52             s += [ f'there is {c}' ]
53
54             if square_i[r] >= height - height//3: s += [ f'{c} bottom' ]
55             if square_i[r] < height//3: s += [ f'{c} top' ]
56             if square_j[r] >= width - width//3: s += [ f'{c} right' ]
57             if square_j[r] < width//3: s += [ f'{c} left' ]
58
59             for t, d in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]:
60                 if square_i[r] > square_i[t]: s += [ f'{c} below {d}' ]
61                 if square_i[r] < square_i[t]: s += [ f'{c} above {d}' ]
62                 if square_j[r] > square_j[t]: s += [ f'{c} right of {d}' ]
63                 if square_j[r] < square_j[t]: s += [ f'{c} left of {d}' ]
64
65         # pick at most max_nb_statements at random
66
67         nb_statements = torch.randint(max_nb_statements, (1,)) + 1
68         s = ' <sep> '.join([ s[k] for k in torch.randperm(len(s))[:nb_statements] ] )
69         s += ' <img> ' + ' '.join([ f'{color_names[n]}' for n in img ])
70
71         descr += [ s ]
72
73     return descr
74
75 ######################################################################
76
77 def descr2img(descr, height = 6, width = 8):
78
79     def token2color(t):
80         try:
81             return color_tokens[t]
82         except KeyError:
83             return [ 128, 128, 128 ]
84
85     def img_descr(x):
86         u = x.split('<img>', 1)
87         return u[1] if len(u) > 1 else ''
88
89     img = torch.full((len(descr), 3, height, width), 255)
90     d = [ img_descr(x) for x in descr ]
91     d = [ u.strip().split(' ')[:height * width] for u in d ]
92     d = [ u + [ '<unk>' ] * (height * width - len(u)) for u in d ]
93     d = [ [ token2color(t) for t in u ] for u in d ]
94     img = torch.tensor(d).permute(0, 2, 1)
95     img = img.reshape(img.size(0), 3, height, width)
96
97     return img
98
99 ######################################################################
100
101 if __name__ == '__main__':
102     descr = generate(5)
103     for d in descr:
104         print(d)
105         print()
106
107     img = descr2img(descr)
108     print(img.size())
109
110     torchvision.utils.save_image(img / 255.,
111                                  'picoclvr_example.png', nrow = 16, pad_value = 0.8)
112
113     import time
114
115     start_time = time.perf_counter()
116     descr = generate(10000)
117     end_time = time.perf_counter()
118     print(f'{len(descr) / (end_time - start_time):.02f} samples per second')
119
120 ######################################################################