Update.
[picoclvr.git] / world.py
1 #!/usr/bin/env python
2
3 import math, sys, tqdm
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9 import cairo
10
11 ######################################################################
12
13
14 class Box:
15     nb_rgb_levels = 10
16
17     def __init__(self, x, y, w, h, r, g, b):
18         self.x = x
19         self.y = y
20         self.w = w
21         self.h = h
22         self.r = r
23         self.g = g
24         self.b = b
25
26     def collision(self, scene):
27         for c in scene:
28             if (
29                 self is not c
30                 and max(self.x, c.x) <= min(self.x + self.w, c.x + c.w)
31                 and max(self.y, c.y) <= min(self.y + self.h, c.y + c.h)
32             ):
33                 return True
34         return False
35
36
37 ######################################################################
38
39
40 class Normalizer(nn.Module):
41     def __init__(self, mu, std):
42         super().__init__()
43         self.register_buffer("mu", mu)
44         self.register_buffer("log_var", 2 * torch.log(std))
45
46     def forward(self, x):
47         return (x - self.mu) / torch.exp(self.log_var / 2.0)
48
49
50 class SignSTE(nn.Module):
51     def __init__(self):
52         super().__init__()
53
54     def forward(self, x):
55         # torch.sign() takes three values
56         s = (x >= 0).float() * 2 - 1
57
58         if self.training:
59             u = torch.tanh(x)
60             return s + u - u.detach()
61         else:
62             return s
63
64
65 def loss_H(binary_logits, h_threshold=1):
66     p = binary_logits.sigmoid().mean(0)
67     h = (-p.xlogy(p) - (1 - p).xlogy(1 - p)) / math.log(2)
68     h.clamp_(max=h_threshold)
69     return h_threshold - h.mean()
70
71
72 def train_encoder(
73     train_input,
74     test_input,
75     depth=2,
76     dim_hidden=48,
77     nb_bits_per_token=8,
78     lambda_entropy=0.0,
79     lr_start=1e-3,
80     lr_end=1e-4,
81     nb_epochs=10,
82     batch_size=25,
83     logger=None,
84     device=torch.device("cpu"),
85 ):
86     if logger is None:
87         logger = lambda s: print(s)
88
89     mu, std = train_input.float().mean(), train_input.float().std()
90
91     def encoder_core(depth, dim):
92         l = [
93             [
94                 nn.Conv2d(
95                     dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
96                 ),
97                 nn.ReLU(),
98                 nn.Conv2d(dim * 2**k, dim * 2 ** (k + 1), kernel_size=2, stride=2),
99                 nn.ReLU(),
100             ]
101             for k in range(depth)
102         ]
103
104         return nn.Sequential(*[x for m in l for x in m])
105
106     def decoder_core(depth, dim):
107         l = [
108             [
109                 nn.ConvTranspose2d(
110                     dim * 2 ** (k + 1), dim * 2**k, kernel_size=2, stride=2
111                 ),
112                 nn.ReLU(),
113                 nn.ConvTranspose2d(
114                     dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
115                 ),
116                 nn.ReLU(),
117             ]
118             for k in range(depth - 1, -1, -1)
119         ]
120
121         return nn.Sequential(*[x for m in l for x in m])
122
123     encoder = nn.Sequential(
124         Normalizer(mu, std),
125         nn.Conv2d(3, dim_hidden, kernel_size=1, stride=1),
126         nn.ReLU(),
127         # 64x64
128         encoder_core(depth=depth, dim=dim_hidden),
129         # 8x8
130         nn.Conv2d(dim_hidden * 2**depth, nb_bits_per_token, kernel_size=1, stride=1),
131     )
132
133     quantizer = SignSTE()
134
135     decoder = nn.Sequential(
136         nn.Conv2d(nb_bits_per_token, dim_hidden * 2**depth, kernel_size=1, stride=1),
137         # 8x8
138         decoder_core(depth=depth, dim=dim_hidden),
139         # 64x64
140         nn.ConvTranspose2d(dim_hidden, 3 * Box.nb_rgb_levels, kernel_size=1, stride=1),
141     )
142
143     model = nn.Sequential(encoder, decoder)
144
145     nb_parameters = sum(p.numel() for p in model.parameters())
146
147     logger(f"nb_parameters {nb_parameters}")
148
149     model.to(device)
150
151     for k in range(nb_epochs):
152         lr = math.exp(
153             math.log(lr_start) + math.log(lr_end / lr_start) / (nb_epochs - 1) * k
154         )
155         optimizer = torch.optim.Adam(model.parameters(), lr=lr)
156
157         acc_train_loss = 0.0
158
159         for input in tqdm.tqdm(train_input.split(batch_size), desc="vqae-train"):
160             input = input.to(device)
161             z = encoder(input)
162             zq = z if k < 2 else quantizer(z)
163             output = decoder(zq)
164
165             output = output.reshape(
166                 output.size(0), -1, 3, output.size(2), output.size(3)
167             )
168
169             train_loss = F.cross_entropy(output, input)
170
171             if lambda_entropy > 0:
172                 loss = loss + lambda_entropy * loss_H(z, h_threshold=0.5)
173
174             acc_train_loss += train_loss.item() * input.size(0)
175
176             optimizer.zero_grad()
177             train_loss.backward()
178             optimizer.step()
179
180         acc_test_loss = 0.0
181
182         for input in tqdm.tqdm(test_input.split(batch_size), desc="vqae-test"):
183             input = input.to(device)
184             z = encoder(input)
185             zq = z if k < 1 else quantizer(z)
186             output = decoder(zq)
187
188             output = output.reshape(
189                 output.size(0), -1, 3, output.size(2), output.size(3)
190             )
191
192             test_loss = F.cross_entropy(output, input)
193
194             acc_test_loss += test_loss.item() * input.size(0)
195
196         train_loss = acc_train_loss / train_input.size(0)
197         test_loss = acc_test_loss / test_input.size(0)
198
199         logger(f"train_ae {k} lr {lr} train_loss {train_loss} test_loss {test_loss}")
200         sys.stdout.flush()
201
202     return encoder, quantizer, decoder
203
204
205 ######################################################################
206
207
208 def scene2tensor(xh, yh, scene, size):
209     width, height = size, size
210     pixel_map = torch.ByteTensor(width, height, 4).fill_(255)
211     data = pixel_map.numpy()
212     surface = cairo.ImageSurface.create_for_data(
213         data, cairo.FORMAT_ARGB32, width, height
214     )
215
216     ctx = cairo.Context(surface)
217     ctx.set_fill_rule(cairo.FILL_RULE_EVEN_ODD)
218
219     for b in scene:
220         ctx.move_to(b.x * size, b.y * size)
221         ctx.rel_line_to(b.w * size, 0)
222         ctx.rel_line_to(0, b.h * size)
223         ctx.rel_line_to(-b.w * size, 0)
224         ctx.close_path()
225         ctx.set_source_rgba(
226             b.r / (Box.nb_rgb_levels - 1),
227             b.g / (Box.nb_rgb_levels - 1),
228             b.b / (Box.nb_rgb_levels - 1),
229             1.0,
230         )
231         ctx.fill()
232
233     hs = size * 0.1
234     ctx.set_source_rgba(0.0, 0.0, 0.0, 1.0)
235     ctx.move_to(xh * size - hs / 2, yh * size - hs / 2)
236     ctx.rel_line_to(hs, 0)
237     ctx.rel_line_to(0, hs)
238     ctx.rel_line_to(-hs, 0)
239     ctx.close_path()
240     ctx.fill()
241
242     return (
243         pixel_map[None, :, :, :3]
244         .flip(-1)
245         .permute(0, 3, 1, 2)
246         .long()
247         .mul(Box.nb_rgb_levels)
248         .floor_divide(256)
249     )
250
251
252 def random_scene(nb_insert_attempts=3):
253     scene = []
254     colors = [
255         ((Box.nb_rgb_levels - 1), 0, 0),
256         (0, (Box.nb_rgb_levels - 1), 0),
257         (0, 0, (Box.nb_rgb_levels - 1)),
258         ((Box.nb_rgb_levels - 1), (Box.nb_rgb_levels - 1), 0),
259         (
260             (Box.nb_rgb_levels * 2) // 3,
261             (Box.nb_rgb_levels * 2) // 3,
262             (Box.nb_rgb_levels * 2) // 3,
263         ),
264     ]
265
266     for k in range(nb_insert_attempts):
267         wh = torch.rand(2) * 0.2 + 0.2
268         xy = torch.rand(2) * (1 - wh)
269         c = colors[torch.randint(len(colors), (1,))]
270         b = Box(
271             xy[0].item(), xy[1].item(), wh[0].item(), wh[1].item(), c[0], c[1], c[2]
272         )
273         if not b.collision(scene):
274             scene.append(b)
275
276     return scene
277
278
279 def generate_episode(steps, size=64):
280     delta = 0.1
281     effects = [
282         (False, 0, 0),
283         (False, delta, 0),
284         (False, 0, delta),
285         (False, -delta, 0),
286         (False, 0, -delta),
287         (True, delta, 0),
288         (True, 0, delta),
289         (True, -delta, 0),
290         (True, 0, -delta),
291     ]
292
293     while True:
294         frames = []
295
296         scene = random_scene()
297         xh, yh = tuple(x.item() for x in torch.rand(2))
298
299         actions = torch.randint(len(effects), (len(steps),))
300         nb_changes = 0
301
302         for s, a in zip(steps, actions):
303             if s:
304                 frames.append(scene2tensor(xh, yh, scene, size=size))
305
306             grasp, dx, dy = effects[a]
307
308             if grasp:
309                 for b in scene:
310                     if b.x <= xh and b.x + b.w >= xh and b.y <= yh and b.y + b.h >= yh:
311                         x, y = b.x, b.y
312                         b.x += dx
313                         b.y += dy
314                         if (
315                             b.x < 0
316                             or b.y < 0
317                             or b.x + b.w > 1
318                             or b.y + b.h > 1
319                             or b.collision(scene)
320                         ):
321                             b.x, b.y = x, y
322                         else:
323                             xh += dx
324                             yh += dy
325                             nb_changes += 1
326             else:
327                 x, y = xh, yh
328                 xh += dx
329                 yh += dy
330                 if xh < 0 or xh > 1 or yh < 0 or yh > 1:
331                     xh, yh = x, y
332
333         if nb_changes > len(steps) // 3:
334             break
335
336     return frames, actions
337
338
339 ######################################################################
340
341
342 def generate_episodes(nb, steps):
343     all_frames, all_actions = [], []
344     for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world-data"):
345         frames, actions = generate_episode(steps)
346         all_frames += frames
347         all_actions += [actions[None, :]]
348     return torch.cat(all_frames, 0).contiguous(), torch.cat(all_actions, 0)
349
350
351 def create_data_and_processors(
352     nb_train_samples,
353     nb_test_samples,
354     mode,
355     nb_steps,
356     nb_epochs=10,
357     device=torch.device("cpu"),
358     device_storage=torch.device("cpu"),
359     logger=None,
360 ):
361     assert mode in ["first_last"]
362
363     if mode == "first_last":
364         steps = [True] + [False] * (nb_steps + 1) + [True]
365
366     train_input, train_actions = generate_episodes(nb_train_samples, steps)
367     train_input, train_actions = train_input.to(device_storage), train_actions.to(
368         device_storage
369     )
370     test_input, test_actions = generate_episodes(nb_test_samples, steps)
371     test_input, test_actions = test_input.to(device_storage), test_actions.to(
372         device_storage
373     )
374
375     encoder, quantizer, decoder = train_encoder(
376         train_input,
377         test_input,
378         lambda_entropy=1.0,
379         nb_epochs=nb_epochs,
380         logger=logger,
381         device=device,
382     )
383     encoder.train(False)
384     quantizer.train(False)
385     decoder.train(False)
386
387     z = encoder(train_input[:1].to(device))
388     pow2 = (2 ** torch.arange(z.size(1), device=device))[None, None, :]
389     z_h, z_w = z.size(2), z.size(3)
390
391     def frame2seq(input, batch_size=25):
392         seq = []
393         p = pow2.to(device)
394         for x in input.split(batch_size):
395             x = x.to(device)
396             z = encoder(x)
397             ze_bool = (quantizer(z) >= 0).long()
398             output = (
399                 ze_bool.permute(0, 2, 3, 1).reshape(
400                     ze_bool.size(0), -1, ze_bool.size(1)
401                 )
402                 * p
403             ).sum(-1)
404
405             seq.append(output)
406
407         return torch.cat(seq, dim=0)
408
409     def seq2frame(input, batch_size=25, T=1e-2):
410         frames = []
411         p = pow2.to(device)
412         for seq in input.split(batch_size):
413             seq = seq.to(device)
414             zd_bool = (seq[:, :, None] // p) % 2
415             zd_bool = zd_bool.reshape(zd_bool.size(0), z_h, z_w, -1).permute(0, 3, 1, 2)
416             logits = decoder(zd_bool * 2.0 - 1.0)
417             logits = logits.reshape(
418                 logits.size(0), -1, 3, logits.size(2), logits.size(3)
419             ).permute(0, 2, 3, 4, 1)
420             output = torch.distributions.categorical.Categorical(
421                 logits=logits / T
422             ).sample()
423
424             frames.append(output)
425
426         return torch.cat(frames, dim=0)
427
428     return train_input, train_actions, test_input, test_actions, frame2seq, seq2frame
429
430
431 ######################################################################
432
433 if __name__ == "__main__":
434     (
435         train_input,
436         train_actions,
437         test_input,
438         test_actions,
439         frame2seq,
440         seq2frame,
441     ) = create_data_and_processors(
442         # 10000, 1000,
443         100,
444         100,
445         nb_epochs=2,
446         mode="first_last",
447         nb_steps=20,
448     )
449
450     input = test_input[:64]
451
452     seq = frame2seq(input)
453
454     print(f"{seq.size()=} {seq.dtype=} {seq.min()=} {seq.max()=}")
455
456     output = seq2frame(seq)
457
458     torchvision.utils.save_image(
459         input.float() / (Box.nb_rgb_levels - 1), "orig.png", nrow=8
460     )
461
462     torchvision.utils.save_image(
463         output.float() / (Box.nb_rgb_levels - 1), "qtiz.png", nrow=8
464     )