fb8609d82109ce2f29d0612a553bfe77d36a74c5
[picoclvr.git] / world.py
1 #!/usr/bin/env python
2
3 import math, sys, tqdm
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9 import cairo
10
11 ######################################################################
12
13
14 class Box:
15     nb_rgb_levels = 10
16
17     def __init__(self, x, y, w, h, r, g, b):
18         self.x = x
19         self.y = y
20         self.w = w
21         self.h = h
22         self.r = r
23         self.g = g
24         self.b = b
25
26     def collision(self, scene):
27         for c in scene:
28             if (
29                 self is not c
30                 and max(self.x, c.x) <= min(self.x + self.w, c.x + c.w)
31                 and max(self.y, c.y) <= min(self.y + self.h, c.y + c.h)
32             ):
33                 return True
34         return False
35
36
37 ######################################################################
38
39
40 class Normalizer(nn.Module):
41     def __init__(self, mu, std):
42         super().__init__()
43         self.register_buffer("mu", mu)
44         self.register_buffer("log_var", 2 * torch.log(std))
45
46     def forward(self, x):
47         return (x - self.mu) / torch.exp(self.log_var / 2.0)
48
49
50 class SignSTE(nn.Module):
51     def __init__(self):
52         super().__init__()
53
54     def forward(self, x):
55         # torch.sign() takes three values
56         s = (x >= 0).float() * 2 - 1
57
58         if self.training:
59             u = torch.tanh(x)
60             return s + u - u.detach()
61         else:
62             return s
63
64
65 def train_encoder(
66     train_input,
67     test_input,
68     depth=2,
69     dim_hidden=48,
70     nb_bits_per_token=8,
71     lr_start=1e-3,
72     lr_end=1e-4,
73     nb_epochs=10,
74     batch_size=25,
75     logger=None,
76     device=torch.device("cpu"),
77 ):
78     if logger is None:
79         logger = lambda s: print(s)
80
81     mu, std = train_input.float().mean(), train_input.float().std()
82
83     def encoder_core(depth, dim):
84         l = [
85             [
86                 nn.Conv2d(
87                     dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
88                 ),
89                 nn.ReLU(),
90                 nn.Conv2d(dim * 2**k, dim * 2 ** (k + 1), kernel_size=2, stride=2),
91                 nn.ReLU(),
92             ]
93             for k in range(depth)
94         ]
95
96         return nn.Sequential(*[x for m in l for x in m])
97
98     def decoder_core(depth, dim):
99         l = [
100             [
101                 nn.ConvTranspose2d(
102                     dim * 2 ** (k + 1), dim * 2**k, kernel_size=2, stride=2
103                 ),
104                 nn.ReLU(),
105                 nn.ConvTranspose2d(
106                     dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
107                 ),
108                 nn.ReLU(),
109             ]
110             for k in range(depth - 1, -1, -1)
111         ]
112
113         return nn.Sequential(*[x for m in l for x in m])
114
115     encoder = nn.Sequential(
116         Normalizer(mu, std),
117         nn.Conv2d(3, dim_hidden, kernel_size=1, stride=1),
118         nn.ReLU(),
119         # 64x64
120         encoder_core(depth=depth, dim=dim_hidden),
121         # 8x8
122         nn.Conv2d(dim_hidden * 2**depth, nb_bits_per_token, kernel_size=1, stride=1),
123     )
124
125     quantizer = SignSTE()
126
127     decoder = nn.Sequential(
128         nn.Conv2d(nb_bits_per_token, dim_hidden * 2**depth, kernel_size=1, stride=1),
129         # 8x8
130         decoder_core(depth=depth, dim=dim_hidden),
131         # 64x64
132         nn.ConvTranspose2d(dim_hidden, 3 * Box.nb_rgb_levels, kernel_size=1, stride=1),
133     )
134
135     model = nn.Sequential(encoder, decoder)
136
137     nb_parameters = sum(p.numel() for p in model.parameters())
138
139     logger(f"nb_parameters {nb_parameters}")
140
141     model.to(device)
142
143     for k in range(nb_epochs):
144         lr = math.exp(
145             math.log(lr_start) + math.log(lr_end / lr_start) / (nb_epochs - 1) * k
146         )
147         optimizer = torch.optim.Adam(model.parameters(), lr=lr)
148
149         acc_train_loss = 0.0
150
151         for input in tqdm.tqdm(train_input.split(batch_size), desc="vqae-train"):
152             z = encoder(input)
153             zq = z if k < 2 else quantizer(z)
154             output = decoder(zq)
155
156             output = output.reshape(
157                 output.size(0), -1, 3, output.size(2), output.size(3)
158             )
159
160             train_loss = F.cross_entropy(output, input)
161
162             acc_train_loss += train_loss.item() * input.size(0)
163
164             optimizer.zero_grad()
165             train_loss.backward()
166             optimizer.step()
167
168         acc_test_loss = 0.0
169
170         for input in tqdm.tqdm(test_input.split(batch_size), desc="vqae-test"):
171             z = encoder(input)
172             zq = z if k < 1 else quantizer(z)
173             output = decoder(zq)
174
175             output = output.reshape(
176                 output.size(0), -1, 3, output.size(2), output.size(3)
177             )
178
179             test_loss = F.cross_entropy(output, input)
180
181             acc_test_loss += test_loss.item() * input.size(0)
182
183         train_loss = acc_train_loss / train_input.size(0)
184         test_loss = acc_test_loss / test_input.size(0)
185
186         logger(f"train_ae {k} lr {lr} train_loss {train_loss} test_loss {test_loss}")
187         sys.stdout.flush()
188
189     return encoder, quantizer, decoder
190
191
192 ######################################################################
193
194
195 def scene2tensor(xh, yh, scene, size):
196     width, height = size, size
197     pixel_map = torch.ByteTensor(width, height, 4).fill_(255)
198     data = pixel_map.numpy()
199     surface = cairo.ImageSurface.create_for_data(
200         data, cairo.FORMAT_ARGB32, width, height
201     )
202
203     ctx = cairo.Context(surface)
204     ctx.set_fill_rule(cairo.FILL_RULE_EVEN_ODD)
205
206     for b in scene:
207         ctx.move_to(b.x * size, b.y * size)
208         ctx.rel_line_to(b.w * size, 0)
209         ctx.rel_line_to(0, b.h * size)
210         ctx.rel_line_to(-b.w * size, 0)
211         ctx.close_path()
212         ctx.set_source_rgba(
213             b.r / (Box.nb_rgb_levels - 1),
214             b.g / (Box.nb_rgb_levels - 1),
215             b.b / (Box.nb_rgb_levels - 1),
216             1.0,
217         )
218         ctx.fill()
219
220     hs = size * 0.1
221     ctx.set_source_rgba(0.0, 0.0, 0.0, 1.0)
222     ctx.move_to(xh * size - hs / 2, yh * size - hs / 2)
223     ctx.rel_line_to(hs, 0)
224     ctx.rel_line_to(0, hs)
225     ctx.rel_line_to(-hs, 0)
226     ctx.close_path()
227     ctx.fill()
228
229     return (
230         pixel_map[None, :, :, :3]
231         .flip(-1)
232         .permute(0, 3, 1, 2)
233         .long()
234         .mul(Box.nb_rgb_levels)
235         .floor_divide(256)
236     )
237
238
239 def random_scene():
240     scene = []
241     colors = [
242         ((Box.nb_rgb_levels - 1), 0, 0),
243         (0, (Box.nb_rgb_levels - 1), 0),
244         (0, 0, (Box.nb_rgb_levels - 1)),
245         ((Box.nb_rgb_levels - 1), (Box.nb_rgb_levels - 1), 0),
246         (
247             (Box.nb_rgb_levels * 2) // 3,
248             (Box.nb_rgb_levels * 2) // 3,
249             (Box.nb_rgb_levels * 2) // 3,
250         ),
251     ]
252
253     for k in range(10):
254         wh = torch.rand(2) * 0.2 + 0.2
255         xy = torch.rand(2) * (1 - wh)
256         c = colors[torch.randint(len(colors), (1,))]
257         b = Box(
258             xy[0].item(), xy[1].item(), wh[0].item(), wh[1].item(), c[0], c[1], c[2]
259         )
260         if not b.collision(scene):
261             scene.append(b)
262
263     return scene
264
265
266 def generate_episode(steps, size=64):
267     delta = 0.1
268     effects = [
269         (False, 0, 0),
270         (False, delta, 0),
271         (False, 0, delta),
272         (False, -delta, 0),
273         (False, 0, -delta),
274         (True, delta, 0),
275         (True, 0, delta),
276         (True, -delta, 0),
277         (True, 0, -delta),
278     ]
279
280     while True:
281         frames = []
282
283         scene = random_scene()
284         xh, yh = tuple(x.item() for x in torch.rand(2))
285
286         actions = torch.randint(len(effects), (len(steps),))
287         change = False
288
289         for s, a in zip(steps, actions):
290             if s:
291                 frames.append(scene2tensor(xh, yh, scene, size=size))
292
293             g, dx, dy = effects[a]
294             if g:
295                 for b in scene:
296                     if b.x <= xh and b.x + b.w >= xh and b.y <= yh and b.y + b.h >= yh:
297                         x, y = b.x, b.y
298                         b.x += dx
299                         b.y += dy
300                         if (
301                             b.x < 0
302                             or b.y < 0
303                             or b.x + b.w > 1
304                             or b.y + b.h > 1
305                             or b.collision(scene)
306                         ):
307                             b.x, b.y = x, y
308                         else:
309                             xh += dx
310                             yh += dy
311                             change = True
312             else:
313                 x, y = xh, yh
314                 xh += dx
315                 yh += dy
316                 if xh < 0 or xh > 1 or yh < 0 or yh > 1:
317                     xh, yh = x, y
318
319         if change:
320             break
321
322     return frames, actions
323
324
325 ######################################################################
326
327
328 def generate_episodes(nb, steps):
329     all_frames, all_actions = [], []
330     for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world-data"):
331         frames, actions = generate_episode(steps)
332         all_frames += frames
333         all_actions += [actions[None, :]]
334     return torch.cat(all_frames, 0).contiguous(), torch.cat(all_actions, 0)
335
336
337 def create_data_and_processors(
338     nb_train_samples,
339     nb_test_samples,
340     mode,
341     nb_steps,
342     nb_epochs=10,
343     device=torch.device("cpu"),
344     logger=None,
345 ):
346     assert mode in ["first_last"]
347
348     if mode == "first_last":
349         steps = [True] + [False] * (nb_steps + 1) + [True]
350
351     train_input, train_actions = generate_episodes(nb_train_samples, steps)
352     train_input, train_actions = train_input.to(device), train_actions.to(device)
353     test_input, test_actions = generate_episodes(nb_test_samples, steps)
354     test_input, test_actions = test_input.to(device), test_actions.to(device)
355
356     encoder, quantizer, decoder = train_encoder(
357         train_input, test_input, nb_epochs=nb_epochs, logger=logger, device=device
358     )
359     encoder.train(False)
360     quantizer.train(False)
361     decoder.train(False)
362
363     z = encoder(train_input[:1])
364     pow2 = (2 ** torch.arange(z.size(1), device=z.device))[None, None, :]
365     z_h, z_w = z.size(2), z.size(3)
366
367     def frame2seq(input, batch_size=25):
368         seq = []
369
370         for x in input.split(batch_size):
371             z = encoder(x)
372             ze_bool = (quantizer(z) >= 0).long()
373             output = (
374                 ze_bool.permute(0, 2, 3, 1).reshape(
375                     ze_bool.size(0), -1, ze_bool.size(1)
376                 )
377                 * pow2
378             ).sum(-1)
379
380             seq.append(output)
381
382         return torch.cat(seq, dim=0)
383
384     def seq2frame(input, batch_size=25, T=1e-2):
385         frames = []
386
387         for seq in input.split(batch_size):
388             zd_bool = (seq[:, :, None] // pow2) % 2
389             zd_bool = zd_bool.reshape(zd_bool.size(0), z_h, z_w, -1).permute(0, 3, 1, 2)
390             logits = decoder(zd_bool * 2.0 - 1.0)
391             logits = logits.reshape(
392                 logits.size(0), -1, 3, logits.size(2), logits.size(3)
393             ).permute(0, 2, 3, 4, 1)
394             output = torch.distributions.categorical.Categorical(
395                 logits=logits / T
396             ).sample()
397
398             frames.append(output)
399
400         return torch.cat(frames, dim=0)
401
402     return train_input, train_actions, test_input, test_actions, frame2seq, seq2frame
403
404
405 ######################################################################
406
407 if __name__ == "__main__":
408     (
409         train_input,
410         train_actions,
411         test_input,
412         test_actions,
413         frame2seq,
414         seq2frame,
415     ) = create_data_and_processors(
416         # 10000, 1000,
417         100,
418         100,
419         nb_epochs=2,
420         mode="first_last",
421         nb_steps=20,
422     )
423
424     input = test_input[:64]
425
426     seq = frame2seq(input)
427
428     print(f"{seq.size()=} {seq.dtype=} {seq.min()=} {seq.max()=}")
429
430     output = seq2frame(seq)
431
432     torchvision.utils.save_image(
433         input.float() / (Box.nb_rgb_levels - 1), "orig.png", nrow=8
434     )
435
436     torchvision.utils.save_image(
437         output.float() / (Box.nb_rgb_levels - 1), "qtiz.png", nrow=8
438     )