fa305cf1d4d03408924928e6b029f8f8d6a305cb
[picoclvr.git] / world.py
1 #!/usr/bin/env python
2
3 import math, sys, tqdm
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9 import cairo
10
11 ######################################################################
12
13
14 class Box:
15     nb_rgb_levels = 10
16
17     def __init__(self, x, y, w, h, r, g, b):
18         self.x = x
19         self.y = y
20         self.w = w
21         self.h = h
22         self.r = r
23         self.g = g
24         self.b = b
25
26     def collision(self, scene):
27         for c in scene:
28             if (
29                 self is not c
30                 and max(self.x, c.x) <= min(self.x + self.w, c.x + c.w)
31                 and max(self.y, c.y) <= min(self.y + self.h, c.y + c.h)
32             ):
33                 return True
34         return False
35
36
37 ######################################################################
38
39
40 class Normalizer(nn.Module):
41     def __init__(self, mu, std):
42         super().__init__()
43         self.register_buffer("mu", mu)
44         self.register_buffer("log_var", 2 * torch.log(std))
45
46     def forward(self, x):
47         return (x - self.mu) / torch.exp(self.log_var / 2.0)
48
49
50 class SignSTE(nn.Module):
51     def __init__(self):
52         super().__init__()
53
54     def forward(self, x):
55         # torch.sign() takes three values
56         s = (x >= 0).float() * 2 - 1
57
58         if self.training:
59             u = torch.tanh(x)
60             return s + u - u.detach()
61         else:
62             return s
63
64
65 def train_encoder(
66     train_input,
67     test_input,
68     depth=2,
69     dim_hidden=48,
70     nb_bits_per_token=8,
71     lr_start=1e-3,
72     lr_end=1e-4,
73     nb_epochs=10,
74     batch_size=25,
75     logger=None,
76     device=torch.device("cpu"),
77 ):
78     if logger is None:
79         logger = lambda s: print(s)
80
81     mu, std = train_input.float().mean(), train_input.float().std()
82
83     def encoder_core(depth, dim):
84         l = [
85             [
86                 nn.Conv2d(
87                     dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
88                 ),
89                 nn.ReLU(),
90                 nn.Conv2d(dim * 2**k, dim * 2 ** (k + 1), kernel_size=2, stride=2),
91                 nn.ReLU(),
92             ]
93             for k in range(depth)
94         ]
95
96         return nn.Sequential(*[x for m in l for x in m])
97
98     def decoder_core(depth, dim):
99         l = [
100             [
101                 nn.ConvTranspose2d(
102                     dim * 2 ** (k + 1), dim * 2**k, kernel_size=2, stride=2
103                 ),
104                 nn.ReLU(),
105                 nn.ConvTranspose2d(
106                     dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
107                 ),
108                 nn.ReLU(),
109             ]
110             for k in range(depth - 1, -1, -1)
111         ]
112
113         return nn.Sequential(*[x for m in l for x in m])
114
115     encoder = nn.Sequential(
116         Normalizer(mu, std),
117         nn.Conv2d(3, dim_hidden, kernel_size=1, stride=1),
118         nn.ReLU(),
119         # 64x64
120         encoder_core(depth=depth, dim=dim_hidden),
121         # 8x8
122         nn.Conv2d(dim_hidden * 2**depth, nb_bits_per_token, kernel_size=1, stride=1),
123     )
124
125     quantizer = SignSTE()
126
127     decoder = nn.Sequential(
128         nn.Conv2d(nb_bits_per_token, dim_hidden * 2**depth, kernel_size=1, stride=1),
129         # 8x8
130         decoder_core(depth=depth, dim=dim_hidden),
131         # 64x64
132         nn.ConvTranspose2d(dim_hidden, 3 * Box.nb_rgb_levels, kernel_size=1, stride=1),
133     )
134
135     model = nn.Sequential(encoder, decoder)
136
137     nb_parameters = sum(p.numel() for p in model.parameters())
138
139     logger(f"nb_parameters {nb_parameters}")
140
141     model.to(device)
142
143     for k in range(nb_epochs):
144         lr = math.exp(
145             math.log(lr_start) + math.log(lr_end / lr_start) / (nb_epochs - 1) * k
146         )
147         optimizer = torch.optim.Adam(model.parameters(), lr=lr)
148
149         acc_train_loss = 0.0
150
151         for input in tqdm.tqdm(train_input.split(batch_size), desc="vqae-train"):
152             input = input.to(device)
153             z = encoder(input)
154             zq = z if k < 2 else quantizer(z)
155             output = decoder(zq)
156
157             output = output.reshape(
158                 output.size(0), -1, 3, output.size(2), output.size(3)
159             )
160
161             train_loss = F.cross_entropy(output, input)
162
163             acc_train_loss += train_loss.item() * input.size(0)
164
165             optimizer.zero_grad()
166             train_loss.backward()
167             optimizer.step()
168
169         acc_test_loss = 0.0
170
171         for input in tqdm.tqdm(test_input.split(batch_size), desc="vqae-test"):
172             input = input.to(device)
173             z = encoder(input)
174             zq = z if k < 1 else quantizer(z)
175             output = decoder(zq)
176
177             output = output.reshape(
178                 output.size(0), -1, 3, output.size(2), output.size(3)
179             )
180
181             test_loss = F.cross_entropy(output, input)
182
183             acc_test_loss += test_loss.item() * input.size(0)
184
185         train_loss = acc_train_loss / train_input.size(0)
186         test_loss = acc_test_loss / test_input.size(0)
187
188         logger(f"train_ae {k} lr {lr} train_loss {train_loss} test_loss {test_loss}")
189         sys.stdout.flush()
190
191     return encoder, quantizer, decoder
192
193
194 ######################################################################
195
196
197 def scene2tensor(xh, yh, scene, size):
198     width, height = size, size
199     pixel_map = torch.ByteTensor(width, height, 4).fill_(255)
200     data = pixel_map.numpy()
201     surface = cairo.ImageSurface.create_for_data(
202         data, cairo.FORMAT_ARGB32, width, height
203     )
204
205     ctx = cairo.Context(surface)
206     ctx.set_fill_rule(cairo.FILL_RULE_EVEN_ODD)
207
208     for b in scene:
209         ctx.move_to(b.x * size, b.y * size)
210         ctx.rel_line_to(b.w * size, 0)
211         ctx.rel_line_to(0, b.h * size)
212         ctx.rel_line_to(-b.w * size, 0)
213         ctx.close_path()
214         ctx.set_source_rgba(
215             b.r / (Box.nb_rgb_levels - 1),
216             b.g / (Box.nb_rgb_levels - 1),
217             b.b / (Box.nb_rgb_levels - 1),
218             1.0,
219         )
220         ctx.fill()
221
222     hs = size * 0.1
223     ctx.set_source_rgba(0.0, 0.0, 0.0, 1.0)
224     ctx.move_to(xh * size - hs / 2, yh * size - hs / 2)
225     ctx.rel_line_to(hs, 0)
226     ctx.rel_line_to(0, hs)
227     ctx.rel_line_to(-hs, 0)
228     ctx.close_path()
229     ctx.fill()
230
231     return (
232         pixel_map[None, :, :, :3]
233         .flip(-1)
234         .permute(0, 3, 1, 2)
235         .long()
236         .mul(Box.nb_rgb_levels)
237         .floor_divide(256)
238     )
239
240
241 def random_scene():
242     scene = []
243     colors = [
244         ((Box.nb_rgb_levels - 1), 0, 0),
245         (0, (Box.nb_rgb_levels - 1), 0),
246         (0, 0, (Box.nb_rgb_levels - 1)),
247         ((Box.nb_rgb_levels - 1), (Box.nb_rgb_levels - 1), 0),
248         (
249             (Box.nb_rgb_levels * 2) // 3,
250             (Box.nb_rgb_levels * 2) // 3,
251             (Box.nb_rgb_levels * 2) // 3,
252         ),
253     ]
254
255     for k in range(10):
256         wh = torch.rand(2) * 0.2 + 0.2
257         xy = torch.rand(2) * (1 - wh)
258         c = colors[torch.randint(len(colors), (1,))]
259         b = Box(
260             xy[0].item(), xy[1].item(), wh[0].item(), wh[1].item(), c[0], c[1], c[2]
261         )
262         if not b.collision(scene):
263             scene.append(b)
264
265     return scene
266
267
268 def generate_episode(steps, size=64):
269     delta = 0.1
270     effects = [
271         (False, 0, 0),
272         (False, delta, 0),
273         (False, 0, delta),
274         (False, -delta, 0),
275         (False, 0, -delta),
276         (True, delta, 0),
277         (True, 0, delta),
278         (True, -delta, 0),
279         (True, 0, -delta),
280     ]
281
282     while True:
283         frames = []
284
285         scene = random_scene()
286         xh, yh = tuple(x.item() for x in torch.rand(2))
287
288         actions = torch.randint(len(effects), (len(steps),))
289         change = False
290
291         for s, a in zip(steps, actions):
292             if s:
293                 frames.append(scene2tensor(xh, yh, scene, size=size))
294
295             g, dx, dy = effects[a]
296             if g:
297                 for b in scene:
298                     if b.x <= xh and b.x + b.w >= xh and b.y <= yh and b.y + b.h >= yh:
299                         x, y = b.x, b.y
300                         b.x += dx
301                         b.y += dy
302                         if (
303                             b.x < 0
304                             or b.y < 0
305                             or b.x + b.w > 1
306                             or b.y + b.h > 1
307                             or b.collision(scene)
308                         ):
309                             b.x, b.y = x, y
310                         else:
311                             xh += dx
312                             yh += dy
313                             change = True
314             else:
315                 x, y = xh, yh
316                 xh += dx
317                 yh += dy
318                 if xh < 0 or xh > 1 or yh < 0 or yh > 1:
319                     xh, yh = x, y
320
321         if change:
322             break
323
324     return frames, actions
325
326
327 ######################################################################
328
329
330 def generate_episodes(nb, steps):
331     all_frames, all_actions = [], []
332     for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world-data"):
333         frames, actions = generate_episode(steps)
334         all_frames += frames
335         all_actions += [actions[None, :]]
336     return torch.cat(all_frames, 0).contiguous(), torch.cat(all_actions, 0)
337
338
339 def create_data_and_processors(
340     nb_train_samples,
341     nb_test_samples,
342     mode,
343     nb_steps,
344     nb_epochs=10,
345     device=torch.device("cpu"),
346     device_storage=torch.device("cpu"),
347     logger=None,
348 ):
349     assert mode in ["first_last"]
350
351     if mode == "first_last":
352         steps = [True] + [False] * (nb_steps + 1) + [True]
353
354     train_input, train_actions = generate_episodes(nb_train_samples, steps)
355     train_input, train_actions = train_input.to(device_storage), train_actions.to(device_storage)
356     test_input, test_actions = generate_episodes(nb_test_samples, steps)
357     test_input, test_actions = test_input.to(device_storage), test_actions.to(device_storage)
358
359     encoder, quantizer, decoder = train_encoder(
360         train_input, test_input, nb_epochs=nb_epochs, logger=logger, device=device
361     )
362     encoder.train(False)
363     quantizer.train(False)
364     decoder.train(False)
365
366     z = encoder(train_input[:1].to(device))
367     pow2 = (2 ** torch.arange(z.size(1), device=device))[None, None, :]
368     z_h, z_w = z.size(2), z.size(3)
369
370     def frame2seq(input, batch_size=25):
371         seq = []
372         p = pow2.to(device)
373         for x in input.split(batch_size):
374             x=x.to(device)
375             z = encoder(x)
376             ze_bool = (quantizer(z) >= 0).long()
377             output = (
378                 ze_bool.permute(0, 2, 3, 1).reshape(
379                     ze_bool.size(0), -1, ze_bool.size(1)
380                 )
381                 * p
382             ).sum(-1)
383
384             seq.append(output)
385
386         return torch.cat(seq, dim=0)
387
388     def seq2frame(input, batch_size=25, T=1e-2):
389         frames = []
390         p = pow2.to(device)
391         for seq in input.split(batch_size):
392             seq = seq.to(device)
393             zd_bool = (seq[:, :, None] // p) % 2
394             zd_bool = zd_bool.reshape(zd_bool.size(0), z_h, z_w, -1).permute(0, 3, 1, 2)
395             logits = decoder(zd_bool * 2.0 - 1.0)
396             logits = logits.reshape(
397                 logits.size(0), -1, 3, logits.size(2), logits.size(3)
398             ).permute(0, 2, 3, 4, 1)
399             output = torch.distributions.categorical.Categorical(
400                 logits=logits / T
401             ).sample()
402
403             frames.append(output)
404
405         return torch.cat(frames, dim=0)
406
407     return train_input, train_actions, test_input, test_actions, frame2seq, seq2frame
408
409
410 ######################################################################
411
412 if __name__ == "__main__":
413     (
414         train_input,
415         train_actions,
416         test_input,
417         test_actions,
418         frame2seq,
419         seq2frame,
420     ) = create_data_and_processors(
421         # 10000, 1000,
422         100,
423         100,
424         nb_epochs=2,
425         mode="first_last",
426         nb_steps=20,
427     )
428
429     input = test_input[:64]
430
431     seq = frame2seq(input)
432
433     print(f"{seq.size()=} {seq.dtype=} {seq.min()=} {seq.max()=}")
434
435     output = seq2frame(seq)
436
437     torchvision.utils.save_image(
438         input.float() / (Box.nb_rgb_levels - 1), "orig.png", nrow=8
439     )
440
441     torchvision.utils.save_image(
442         output.float() / (Box.nb_rgb_levels - 1), "qtiz.png", nrow=8
443     )