Update.
[mygptrnn.git] / world.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, tqdm
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14 import cairo
15
16 ######################################################################
17
18
19 class Box:
20     nb_rgb_levels = 10
21
22     def __init__(self, x, y, w, h, r, g, b):
23         self.x = x
24         self.y = y
25         self.w = w
26         self.h = h
27         self.r = r
28         self.g = g
29         self.b = b
30
31     def collision(self, scene):
32         for c in scene:
33             if (
34                 self is not c
35                 and max(self.x, c.x) <= min(self.x + self.w, c.x + c.w)
36                 and max(self.y, c.y) <= min(self.y + self.h, c.y + c.h)
37             ):
38                 return True
39         return False
40
41
42 ######################################################################
43
44
45 class Normalizer(nn.Module):
46     def __init__(self, mu, std):
47         super().__init__()
48         self.register_buffer("mu", mu)
49         self.register_buffer("log_var", 2 * torch.log(std))
50
51     def forward(self, x):
52         return (x - self.mu) / torch.exp(self.log_var / 2.0)
53
54
55 class SignSTE(nn.Module):
56     def __init__(self):
57         super().__init__()
58
59     def forward(self, x):
60         # torch.sign() takes three values
61         s = (x >= 0).float() * 2 - 1
62
63         if self.training:
64             u = torch.tanh(x)
65             return s + u - u.detach()
66         else:
67             return s
68
69
70 class DiscreteSampler2d(nn.Module):
71     def __init__(self):
72         super().__init__()
73
74     def forward(self, x):
75         s = (x >= x.max(-3, keepdim=True).values).float()
76
77         if self.training:
78             u = x.softmax(dim=-3)
79             return s + u - u.detach()
80         else:
81             return s
82
83
84 def loss_H(binary_logits, h_threshold=1):
85     p = binary_logits.sigmoid().mean(0)
86     h = (-p.xlogy(p) - (1 - p).xlogy(1 - p)) / math.log(2)
87     h.clamp_(max=h_threshold)
88     return h_threshold - h.mean()
89
90
91 def train_encoder(
92     train_input,
93     test_input,
94     depth,
95     nb_bits_per_token,
96     dim_hidden=48,
97     lambda_entropy=0.0,
98     lr_start=1e-3,
99     lr_end=1e-4,
100     nb_epochs=10,
101     batch_size=25,
102     logger=None,
103     device=torch.device("cpu"),
104 ):
105     mu, std = train_input.float().mean(), train_input.float().std()
106
107     def encoder_core(depth, dim):
108         l = [
109             [
110                 nn.Conv2d(
111                     dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
112                 ),
113                 nn.ReLU(),
114                 nn.Conv2d(dim * 2**k, dim * 2 ** (k + 1), kernel_size=2, stride=2),
115                 nn.ReLU(),
116             ]
117             for k in range(depth)
118         ]
119
120         return nn.Sequential(*[x for m in l for x in m])
121
122     def decoder_core(depth, dim):
123         l = [
124             [
125                 nn.ConvTranspose2d(
126                     dim * 2 ** (k + 1), dim * 2**k, kernel_size=2, stride=2
127                 ),
128                 nn.ReLU(),
129                 nn.ConvTranspose2d(
130                     dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
131                 ),
132                 nn.ReLU(),
133             ]
134             for k in range(depth - 1, -1, -1)
135         ]
136
137         return nn.Sequential(*[x for m in l for x in m])
138
139     encoder = nn.Sequential(
140         Normalizer(mu, std),
141         nn.Conv2d(3, dim_hidden, kernel_size=1, stride=1),
142         nn.ReLU(),
143         # 64x64
144         encoder_core(depth=depth, dim=dim_hidden),
145         # 8x8
146         nn.Conv2d(dim_hidden * 2**depth, nb_bits_per_token, kernel_size=1, stride=1),
147     )
148
149     quantizer = SignSTE()
150
151     decoder = nn.Sequential(
152         nn.Conv2d(nb_bits_per_token, dim_hidden * 2**depth, kernel_size=1, stride=1),
153         # 8x8
154         decoder_core(depth=depth, dim=dim_hidden),
155         # 64x64
156         nn.ConvTranspose2d(dim_hidden, 3 * Box.nb_rgb_levels, kernel_size=1, stride=1),
157     )
158
159     model = nn.Sequential(encoder, decoder)
160
161     nb_parameters = sum(p.numel() for p in model.parameters())
162
163     logger(f"vqae nb_parameters {nb_parameters}")
164
165     model.to(device)
166
167     for k in range(nb_epochs):
168         lr = math.exp(
169             math.log(lr_start) + math.log(lr_end / lr_start) / (nb_epochs - 1) * k
170         )
171         optimizer = torch.optim.Adam(model.parameters(), lr=lr)
172
173         acc_train_loss = 0.0
174
175         for input in tqdm.tqdm(train_input.split(batch_size), desc="vqae-train"):
176             input = input.to(device)
177             z = encoder(input)
178             zq = quantizer(z)
179             output = decoder(zq)
180
181             output = output.reshape(
182                 output.size(0), -1, 3, output.size(2), output.size(3)
183             )
184
185             train_loss = F.cross_entropy(output, input)
186
187             if lambda_entropy > 0:
188                 train_loss = train_loss + lambda_entropy * loss_H(z, h_threshold=0.5)
189
190             acc_train_loss += train_loss.item() * input.size(0)
191
192             optimizer.zero_grad()
193             train_loss.backward()
194             optimizer.step()
195
196         acc_test_loss = 0.0
197
198         for input in tqdm.tqdm(test_input.split(batch_size), desc="vqae-test"):
199             input = input.to(device)
200             z = encoder(input)
201             zq = quantizer(z)
202             output = decoder(zq)
203
204             output = output.reshape(
205                 output.size(0), -1, 3, output.size(2), output.size(3)
206             )
207
208             test_loss = F.cross_entropy(output, input)
209
210             acc_test_loss += test_loss.item() * input.size(0)
211
212         train_loss = acc_train_loss / train_input.size(0)
213         test_loss = acc_test_loss / test_input.size(0)
214
215         logger(f"vqae train {k} lr {lr} train_loss {train_loss} test_loss {test_loss}")
216         sys.stdout.flush()
217
218     return encoder, quantizer, decoder
219
220
221 ######################################################################
222
223
224 def scene2tensor(xh, yh, scene, size):
225     width, height = size, size
226     pixel_map = torch.ByteTensor(width, height, 4).fill_(255)
227     data = pixel_map.numpy()
228     surface = cairo.ImageSurface.create_for_data(
229         data, cairo.FORMAT_ARGB32, width, height
230     )
231
232     ctx = cairo.Context(surface)
233     ctx.set_fill_rule(cairo.FILL_RULE_EVEN_ODD)
234
235     for b in scene:
236         ctx.move_to(b.x * size, b.y * size)
237         ctx.rel_line_to(b.w * size, 0)
238         ctx.rel_line_to(0, b.h * size)
239         ctx.rel_line_to(-b.w * size, 0)
240         ctx.close_path()
241         ctx.set_source_rgba(
242             b.r / (Box.nb_rgb_levels - 1),
243             b.g / (Box.nb_rgb_levels - 1),
244             b.b / (Box.nb_rgb_levels - 1),
245             1.0,
246         )
247         ctx.fill()
248
249     hs = size * 0.1
250     ctx.set_source_rgba(0.0, 0.0, 0.0, 1.0)
251     ctx.move_to(xh * size - hs / 2, yh * size - hs / 2)
252     ctx.rel_line_to(hs, 0)
253     ctx.rel_line_to(0, hs)
254     ctx.rel_line_to(-hs, 0)
255     ctx.close_path()
256     ctx.fill()
257
258     return (
259         pixel_map[None, :, :, :3]
260         .flip(-1)
261         .permute(0, 3, 1, 2)
262         .long()
263         .mul(Box.nb_rgb_levels)
264         .floor_divide(256)
265     )
266
267
268 def random_scene(nb_insert_attempts=3):
269     scene = []
270     colors = [
271         ((Box.nb_rgb_levels - 1), 0, 0),
272         (0, (Box.nb_rgb_levels - 1), 0),
273         (0, 0, (Box.nb_rgb_levels - 1)),
274         ((Box.nb_rgb_levels - 1), (Box.nb_rgb_levels - 1), 0),
275         (
276             (Box.nb_rgb_levels * 2) // 3,
277             (Box.nb_rgb_levels * 2) // 3,
278             (Box.nb_rgb_levels * 2) // 3,
279         ),
280     ]
281
282     for k in range(nb_insert_attempts):
283         wh = torch.rand(2) * 0.2 + 0.2
284         xy = torch.rand(2) * (1 - wh)
285         c = colors[torch.randint(len(colors), (1,))]
286         b = Box(
287             xy[0].item(), xy[1].item(), wh[0].item(), wh[1].item(), c[0], c[1], c[2]
288         )
289         if not b.collision(scene):
290             scene.append(b)
291
292     return scene
293
294
295 def generate_episode(steps, size=64):
296     delta = 0.1
297     effects = [
298         (False, 0, 0),
299         (False, delta, 0),
300         (False, 0, delta),
301         (False, -delta, 0),
302         (False, 0, -delta),
303         (True, delta, 0),
304         (True, 0, delta),
305         (True, -delta, 0),
306         (True, 0, -delta),
307     ]
308
309     while True:
310         frames = []
311
312         scene = random_scene()
313         xh, yh = tuple(x.item() for x in torch.rand(2))
314
315         actions = torch.randint(len(effects), (len(steps),))
316         nb_changes = 0
317
318         for s, a in zip(steps, actions):
319             if s:
320                 frames.append(scene2tensor(xh, yh, scene, size=size))
321
322             grasp, dx, dy = effects[a]
323
324             if grasp:
325                 for b in scene:
326                     if b.x <= xh and b.x + b.w >= xh and b.y <= yh and b.y + b.h >= yh:
327                         x, y = b.x, b.y
328                         b.x += dx
329                         b.y += dy
330                         if (
331                             b.x < 0
332                             or b.y < 0
333                             or b.x + b.w > 1
334                             or b.y + b.h > 1
335                             or b.collision(scene)
336                         ):
337                             b.x, b.y = x, y
338                         else:
339                             xh += dx
340                             yh += dy
341                             nb_changes += 1
342             else:
343                 x, y = xh, yh
344                 xh += dx
345                 yh += dy
346                 if xh < 0 or xh > 1 or yh < 0 or yh > 1:
347                     xh, yh = x, y
348
349         if nb_changes > len(steps) // 3:
350             break
351
352     return frames, actions
353
354
355 ######################################################################
356
357
358 def generate_episodes(nb, steps):
359     all_frames, all_actions = [], []
360     for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world-data"):
361         frames, actions = generate_episode(steps)
362         all_frames += frames
363         all_actions += [actions[None, :]]
364     return torch.cat(all_frames, 0).contiguous(), torch.cat(all_actions, 0)
365
366
367 def create_data_and_processors(
368     nb_train_samples,
369     nb_test_samples,
370     mode,
371     nb_steps,
372     depth=3,
373     nb_bits_per_token=8,
374     nb_epochs=10,
375     device=torch.device("cpu"),
376     device_storage=torch.device("cpu"),
377     logger=None,
378 ):
379     assert mode in ["first_last"]
380
381     if mode == "first_last":
382         steps = [True] + [False] * (nb_steps + 1) + [True]
383
384     if logger is None:
385         logger = lambda s: print(s)
386
387     train_input, train_actions = generate_episodes(nb_train_samples, steps)
388     train_input, train_actions = train_input.to(device_storage), train_actions.to(
389         device_storage
390     )
391     test_input, test_actions = generate_episodes(nb_test_samples, steps)
392     test_input, test_actions = test_input.to(device_storage), test_actions.to(
393         device_storage
394     )
395
396     encoder, quantizer, decoder = train_encoder(
397         train_input,
398         test_input,
399         depth=depth,
400         nb_bits_per_token=nb_bits_per_token,
401         lambda_entropy=1.0,
402         nb_epochs=nb_epochs,
403         logger=logger,
404         device=device,
405     )
406     encoder.train(False)
407     quantizer.train(False)
408     decoder.train(False)
409
410     z = encoder(train_input[:1].to(device))
411     pow2 = (2 ** torch.arange(z.size(1), device=device))[None, None, :]
412     z_h, z_w = z.size(2), z.size(3)
413
414     logger(f"vqae input {train_input[0].size()} output {z[0].size()}")
415
416     def frame2seq(input, batch_size=25):
417         seq = []
418         p = pow2.to(device)
419         for x in input.split(batch_size):
420             x = x.to(device)
421             z = encoder(x)
422             ze_bool = (quantizer(z) >= 0).long()
423             output = (
424                 ze_bool.permute(0, 2, 3, 1).reshape(
425                     ze_bool.size(0), -1, ze_bool.size(1)
426                 )
427                 * p
428             ).sum(-1)
429
430             seq.append(output)
431
432         return torch.cat(seq, dim=0)
433
434     def seq2frame(input, batch_size=25, T=1e-2):
435         frames = []
436         p = pow2.to(device)
437         for seq in input.split(batch_size):
438             seq = seq.to(device)
439             zd_bool = (seq[:, :, None] // p) % 2
440             zd_bool = zd_bool.reshape(zd_bool.size(0), z_h, z_w, -1).permute(0, 3, 1, 2)
441             logits = decoder(zd_bool * 2.0 - 1.0)
442             logits = logits.reshape(
443                 logits.size(0), -1, 3, logits.size(2), logits.size(3)
444             ).permute(0, 2, 3, 4, 1)
445             output = torch.distributions.categorical.Categorical(
446                 logits=logits / T
447             ).sample()
448
449             frames.append(output)
450
451         return torch.cat(frames, dim=0)
452
453     return train_input, train_actions, test_input, test_actions, frame2seq, seq2frame
454
455
456 ######################################################################
457
458 if __name__ == "__main__":
459     (
460         train_input,
461         train_actions,
462         test_input,
463         test_actions,
464         frame2seq,
465         seq2frame,
466     ) = create_data_and_processors(
467         25000,
468         1000,
469         nb_epochs=5,
470         mode="first_last",
471         nb_steps=20,
472     )
473
474     input = test_input[:256]
475
476     seq = frame2seq(input)
477     output = seq2frame(seq)
478
479     torchvision.utils.save_image(
480         input.float() / (Box.nb_rgb_levels - 1), "orig.png", nrow=16
481     )
482
483     torchvision.utils.save_image(
484         output.float() / (Box.nb_rgb_levels - 1), "qtiz.png", nrow=16
485     )