Update.
[picoclvr.git] / world.py
1 #!/usr/bin/env python
2
3 import math, sys, tqdm
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9 import cairo
10
11 ######################################################################
12
13
14 class Box:
15     nb_rgb_levels = 10
16
17     def __init__(self, x, y, w, h, r, g, b):
18         self.x = x
19         self.y = y
20         self.w = w
21         self.h = h
22         self.r = r
23         self.g = g
24         self.b = b
25
26     def collision(self, scene):
27         for c in scene:
28             if (
29                 self is not c
30                 and max(self.x, c.x) <= min(self.x + self.w, c.x + c.w)
31                 and max(self.y, c.y) <= min(self.y + self.h, c.y + c.h)
32             ):
33                 return True
34         return False
35
36
37 ######################################################################
38
39
40 class Normalizer(nn.Module):
41     def __init__(self, mu, std):
42         super().__init__()
43         self.register_buffer("mu", mu)
44         self.register_buffer("log_var", 2 * torch.log(std))
45
46     def forward(self, x):
47         return (x - self.mu) / torch.exp(self.log_var / 2.0)
48
49
50 class SignSTE(nn.Module):
51     def __init__(self):
52         super().__init__()
53
54     def forward(self, x):
55         # torch.sign() takes three values
56         s = (x >= 0).float() * 2 - 1
57
58         if self.training:
59             u = torch.tanh(x)
60             return s + u - u.detach()
61         else:
62             return s
63
64
65 def train_encoder(
66     train_input,
67     test_input,
68     depth=3,
69     dim_hidden=48,
70     nb_bits_per_token=8,
71     lr_start=1e-3,
72     lr_end=1e-4,
73     nb_epochs=10,
74     batch_size=25,
75     device=torch.device("cpu"),
76 ):
77     mu, std = train_input.float().mean(), train_input.float().std()
78
79     def encoder_core(depth, dim):
80         l = [
81             [
82                 nn.Conv2d(
83                     dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
84                 ),
85                 nn.ReLU(),
86                 nn.Conv2d(dim * 2**k, dim * 2 ** (k + 1), kernel_size=2, stride=2),
87                 nn.ReLU(),
88             ]
89             for k in range(depth)
90         ]
91
92         return nn.Sequential(*[x for m in l for x in m])
93
94     def decoder_core(depth, dim):
95         l = [
96             [
97                 nn.ConvTranspose2d(
98                     dim * 2 ** (k + 1), dim * 2**k, kernel_size=2, stride=2
99                 ),
100                 nn.ReLU(),
101                 nn.ConvTranspose2d(
102                     dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
103                 ),
104                 nn.ReLU(),
105             ]
106             for k in range(depth - 1, -1, -1)
107         ]
108
109         return nn.Sequential(*[x for m in l for x in m])
110
111     encoder = nn.Sequential(
112         Normalizer(mu, std),
113         nn.Conv2d(3, dim_hidden, kernel_size=1, stride=1),
114         nn.ReLU(),
115         # 64x64
116         encoder_core(depth=depth, dim=dim_hidden),
117         # 8x8
118         nn.Conv2d(dim_hidden * 2**depth, nb_bits_per_token, kernel_size=1, stride=1),
119     )
120
121     quantizer = SignSTE()
122
123     decoder = nn.Sequential(
124         nn.Conv2d(nb_bits_per_token, dim_hidden * 2**depth, kernel_size=1, stride=1),
125         # 8x8
126         decoder_core(depth=depth, dim=dim_hidden),
127         # 64x64
128         nn.ConvTranspose2d(dim_hidden, 3 * Box.nb_rgb_levels, kernel_size=1, stride=1),
129     )
130
131     model = nn.Sequential(encoder, decoder)
132
133     nb_parameters = sum(p.numel() for p in model.parameters())
134
135     print(f"nb_parameters {nb_parameters}")
136
137     model.to(device)
138
139     for k in range(nb_epochs):
140         lr = math.exp(
141             math.log(lr_start) + math.log(lr_end / lr_start) / (nb_epochs - 1) * k
142         )
143         optimizer = torch.optim.Adam(model.parameters(), lr=lr)
144
145         acc_train_loss = 0.0
146
147         for input in tqdm.tqdm(train_input.split(batch_size), desc="vqae-train"):
148             z = encoder(input)
149             zq = z if k < 1 else quantizer(z)
150             output = decoder(zq)
151
152             output = output.reshape(
153                 output.size(0), -1, 3, output.size(2), output.size(3)
154             )
155
156             train_loss = F.cross_entropy(output, input)
157
158             acc_train_loss += train_loss.item() * input.size(0)
159
160             optimizer.zero_grad()
161             train_loss.backward()
162             optimizer.step()
163
164         acc_test_loss = 0.0
165
166         for input in tqdm.tqdm(test_input.split(batch_size), desc="vqae-test"):
167             z = encoder(input)
168             zq = z if k < 1 else quantizer(z)
169             output = decoder(zq)
170
171             output = output.reshape(
172                 output.size(0), -1, 3, output.size(2), output.size(3)
173             )
174
175             test_loss = F.cross_entropy(output, input)
176
177             acc_test_loss += test_loss.item() * input.size(0)
178
179         train_loss = acc_train_loss / train_input.size(0)
180         test_loss = acc_test_loss / test_input.size(0)
181
182         print(f"train_ae {k} lr {lr} train_loss {train_loss} test_loss {test_loss}")
183         sys.stdout.flush()
184
185     return encoder, quantizer, decoder
186
187
188 ######################################################################
189
190
191 def scene2tensor(xh, yh, scene, size):
192     width, height = size, size
193     pixel_map = torch.ByteTensor(width, height, 4).fill_(255)
194     data = pixel_map.numpy()
195     surface = cairo.ImageSurface.create_for_data(
196         data, cairo.FORMAT_ARGB32, width, height
197     )
198
199     ctx = cairo.Context(surface)
200     ctx.set_fill_rule(cairo.FILL_RULE_EVEN_ODD)
201
202     for b in scene:
203         ctx.move_to(b.x * size, b.y * size)
204         ctx.rel_line_to(b.w * size, 0)
205         ctx.rel_line_to(0, b.h * size)
206         ctx.rel_line_to(-b.w * size, 0)
207         ctx.close_path()
208         ctx.set_source_rgba(
209             b.r / (Box.nb_rgb_levels - 1),
210             b.g / (Box.nb_rgb_levels - 1),
211             b.b / (Box.nb_rgb_levels - 1),
212             1.0,
213         )
214         ctx.fill()
215
216     hs = size * 0.1
217     ctx.set_source_rgba(0.0, 0.0, 0.0, 1.0)
218     ctx.move_to(xh * size - hs / 2, yh * size - hs / 2)
219     ctx.rel_line_to(hs, 0)
220     ctx.rel_line_to(0, hs)
221     ctx.rel_line_to(-hs, 0)
222     ctx.close_path()
223     ctx.fill()
224
225     return (
226         pixel_map[None, :, :, :3]
227         .flip(-1)
228         .permute(0, 3, 1, 2)
229         .long()
230         .mul(Box.nb_rgb_levels)
231         .floor_divide(256)
232     )
233
234
235 def random_scene():
236     scene = []
237     colors = [
238         ((Box.nb_rgb_levels - 1), 0, 0),
239         (0, (Box.nb_rgb_levels - 1), 0),
240         (0, 0, (Box.nb_rgb_levels - 1)),
241         ((Box.nb_rgb_levels - 1), (Box.nb_rgb_levels - 1), 0),
242         (
243             (Box.nb_rgb_levels * 2) // 3,
244             (Box.nb_rgb_levels * 2) // 3,
245             (Box.nb_rgb_levels * 2) // 3,
246         ),
247     ]
248
249     for k in range(10):
250         wh = torch.rand(2) * 0.2 + 0.2
251         xy = torch.rand(2) * (1 - wh)
252         c = colors[torch.randint(len(colors), (1,))]
253         b = Box(
254             xy[0].item(), xy[1].item(), wh[0].item(), wh[1].item(), c[0], c[1], c[2]
255         )
256         if not b.collision(scene):
257             scene.append(b)
258
259     return scene
260
261
262 def generate_episode(steps, size=64):
263     delta = 0.1
264     effects = [
265         (False, 0, 0),
266         (False, delta, 0),
267         (False, 0, delta),
268         (False, -delta, 0),
269         (False, 0, -delta),
270         (True, delta, 0),
271         (True, 0, delta),
272         (True, -delta, 0),
273         (True, 0, -delta),
274     ]
275
276     while True:
277         frames = []
278
279         scene = random_scene()
280         xh, yh = tuple(x.item() for x in torch.rand(2))
281
282         actions = torch.randint(len(effects), (len(steps),))
283         change = False
284
285         for s, a in zip(steps, actions):
286             if s:
287                 frames.append(scene2tensor(xh, yh, scene, size=size))
288
289             g, dx, dy = effects[a]
290             if g:
291                 for b in scene:
292                     if b.x <= xh and b.x + b.w >= xh and b.y <= yh and b.y + b.h >= yh:
293                         x, y = b.x, b.y
294                         b.x += dx
295                         b.y += dy
296                         if (
297                             b.x < 0
298                             or b.y < 0
299                             or b.x + b.w > 1
300                             or b.y + b.h > 1
301                             or b.collision(scene)
302                         ):
303                             b.x, b.y = x, y
304                         else:
305                             xh += dx
306                             yh += dy
307                             change = True
308             else:
309                 x, y = xh, yh
310                 xh += dx
311                 yh += dy
312                 if xh < 0 or xh > 1 or yh < 0 or yh > 1:
313                     xh, yh = x, y
314
315         if change:
316             break
317
318     return frames, actions
319
320
321 ######################################################################
322
323
324 def generate_episodes(nb, steps):
325     all_frames = []
326     for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world-data"):
327         frames, actions = generate_episode(steps)
328         all_frames += frames
329     return torch.cat(all_frames, 0).contiguous()
330
331
332 def create_data_and_processors(nb_train_samples, nb_test_samples, nb_epochs=10):
333     steps = [True] + [False] * 30 + [True]
334     train_input = generate_episodes(nb_train_samples, steps)
335     test_input = generate_episodes(nb_test_samples, steps)
336
337     print(f"{train_input.size()=} {test_input.size()=}")
338
339     encoder, quantizer, decoder = train_encoder(
340         train_input, test_input, nb_epochs=nb_epochs
341     )
342     encoder.train(False)
343     quantizer.train(False)
344     decoder.train(False)
345
346     z = encoder(train_input[:1])
347     pow2 = (2 ** torch.arange(z.size(1), device=z.device))[None, None, :]
348     z_h, z_w = z.size(2), z.size(3)
349
350     def frame2seq(x):
351         z = encoder(x)
352         ze_bool = (quantizer(z) >= 0).long()
353         seq = (
354             ze_bool.permute(0, 2, 3, 1).reshape(ze_bool.size(0), -1, ze_bool.size(1))
355             * pow2
356         ).sum(-1)
357         return seq
358
359     def seq2frame(seq, T=1e-2):
360         zd_bool = (seq[:, :, None] // pow2) % 2
361         zd_bool = zd_bool.reshape(zd_bool.size(0), z_h, z_w, -1).permute(0, 3, 1, 2)
362         logits = decoder(zd_bool * 2.0 - 1.0)
363         logits = logits.reshape(
364             logits.size(0), -1, 3, logits.size(2), logits.size(3)
365         ).permute(0, 2, 3, 4, 1)
366         results = torch.distributions.categorical.Categorical(
367             logits=logits / T
368         ).sample()
369         return results
370
371     return train_input, test_input, frame2seq, seq2frame
372
373
374 ######################################################################
375
376 if __name__ == "__main__":
377     train_input, test_input, frame2seq, seq2frame = create_data_and_processors(
378         10000, 1000
379     )
380
381     input = test_input[:64]
382
383     seq = frame2seq(input)
384
385     print(f"{seq.size()=} {seq.dtype=} {seq.min()=} {seq.max()=}")
386
387     output = seq2frame(seq)
388
389     torchvision.utils.save_image(
390         input.float() / (Box.nb_rgb_levels - 1), "orig.png", nrow=8
391     )
392
393     torchvision.utils.save_image(
394         output.float() / (Box.nb_rgb_levels - 1), "qtiz.png", nrow=8
395     )