5 import torch, torchvision
8 from torch.nn import functional as F
11 ######################################################################
17 def __init__(self, x, y, w, h, r, g, b):
26 def collision(self, scene):
30 and max(self.x, c.x) <= min(self.x + self.w, c.x + c.w)
31 and max(self.y, c.y) <= min(self.y + self.h, c.y + c.h)
37 ######################################################################
40 class Normalizer(nn.Module):
41 def __init__(self, mu, std):
43 self.register_buffer("mu", mu)
44 self.register_buffer("log_var", 2 * torch.log(std))
47 return (x - self.mu) / torch.exp(self.log_var / 2.0)
50 class SignSTE(nn.Module):
55 # torch.sign() takes three values
56 s = (x >= 0).float() * 2 - 1
60 return s + u - u.detach()
65 class DiscreteSampler2d(nn.Module):
70 s = (x >= x.max(-3, keepdim=True).values).float()
74 return s + u - u.detach()
def loss_H(binary_logits, h_threshold=1):
    """Entropy penalty that pushes each code bit toward being used (non-constant).

    The logits are turned into "bit is 1" probabilities with a sigmoid and
    averaged over dim 0 (presumably the batch dimension — confirm with callers).
    The binary entropy of each resulting marginal is measured in bits and capped
    at ``h_threshold``.

    Args:
        binary_logits: tensor of logits, one per binary unit; dim 0 is reduced.
        h_threshold: entropy level (in bits) above which a position is no
            longer penalized.

    Returns:
        Scalar tensor ``h_threshold - mean(min(H, h_threshold))``. It is 0 when
        every position reaches the threshold, and it equals ``h_threshold``
        when every position is deterministic (H = 0).
    """
    # Marginal probability of a 1 at each position, averaged over dim 0.
    prob_one = torch.sigmoid(binary_logits).mean(dim=0)
    prob_zero = 1 - prob_one

    # Binary entropy: xlogy(p, p) = p*log(p) and is defined as 0 at p = 0.
    # Dividing by log(2) converts nats to bits.
    entropy_bits = (
        -prob_one.xlogy(prob_one) - prob_zero.xlogy(prob_zero)
    ) / math.log(2)

    # Cap at the threshold so positions already above it contribute no gradient.
    capped_entropy = entropy_bits.clamp(max=h_threshold)

    return h_threshold - capped_entropy.mean()
98 device=torch.device("cpu"),
100 mu, std = train_input.float().mean(), train_input.float().std()
102 def encoder_core(depth, dim):
106 dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
109 nn.Conv2d(dim * 2**k, dim * 2 ** (k + 1), kernel_size=2, stride=2),
112 for k in range(depth)
115 return nn.Sequential(*[x for m in l for x in m])
117 def decoder_core(depth, dim):
121 dim * 2 ** (k + 1), dim * 2**k, kernel_size=2, stride=2
125 dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
129 for k in range(depth - 1, -1, -1)
132 return nn.Sequential(*[x for m in l for x in m])
134 encoder = nn.Sequential(
136 nn.Conv2d(3, dim_hidden, kernel_size=1, stride=1),
139 encoder_core(depth=depth, dim=dim_hidden),
141 nn.Conv2d(dim_hidden * 2**depth, nb_bits_per_token, kernel_size=1, stride=1),
144 quantizer = SignSTE()
146 decoder = nn.Sequential(
147 nn.Conv2d(nb_bits_per_token, dim_hidden * 2**depth, kernel_size=1, stride=1),
149 decoder_core(depth=depth, dim=dim_hidden),
151 nn.ConvTranspose2d(dim_hidden, 3 * Box.nb_rgb_levels, kernel_size=1, stride=1),
154 model = nn.Sequential(encoder, decoder)
156 nb_parameters = sum(p.numel() for p in model.parameters())
158 logger(f"vqae nb_parameters {nb_parameters}")
162 for k in range(nb_epochs):
164 math.log(lr_start) + math.log(lr_end / lr_start) / (nb_epochs - 1) * k
166 optimizer = torch.optim.Adam(model.parameters(), lr=lr)
170 for input in tqdm.tqdm(train_input.split(batch_size), desc="vqae-train"):
171 input = input.to(device)
176 output = output.reshape(
177 output.size(0), -1, 3, output.size(2), output.size(3)
180 train_loss = F.cross_entropy(output, input)
182 if lambda_entropy > 0:
183 train_loss = train_loss + lambda_entropy * loss_H(z, h_threshold=0.5)
185 acc_train_loss += train_loss.item() * input.size(0)
187 optimizer.zero_grad()
188 train_loss.backward()
193 for input in tqdm.tqdm(test_input.split(batch_size), desc="vqae-test"):
194 input = input.to(device)
199 output = output.reshape(
200 output.size(0), -1, 3, output.size(2), output.size(3)
203 test_loss = F.cross_entropy(output, input)
205 acc_test_loss += test_loss.item() * input.size(0)
207 train_loss = acc_train_loss / train_input.size(0)
208 test_loss = acc_test_loss / test_input.size(0)
210 logger(f"vqae train {k} lr {lr} train_loss {train_loss} test_loss {test_loss}")
213 return encoder, quantizer, decoder
216 ######################################################################
219 def scene2tensor(xh, yh, scene, size):
220 width, height = size, size
221 pixel_map = torch.ByteTensor(width, height, 4).fill_(255)
222 data = pixel_map.numpy()
223 surface = cairo.ImageSurface.create_for_data(
224 data, cairo.FORMAT_ARGB32, width, height
227 ctx = cairo.Context(surface)
228 ctx.set_fill_rule(cairo.FILL_RULE_EVEN_ODD)
231 ctx.move_to(b.x * size, b.y * size)
232 ctx.rel_line_to(b.w * size, 0)
233 ctx.rel_line_to(0, b.h * size)
234 ctx.rel_line_to(-b.w * size, 0)
237 b.r / (Box.nb_rgb_levels - 1),
238 b.g / (Box.nb_rgb_levels - 1),
239 b.b / (Box.nb_rgb_levels - 1),
245 ctx.set_source_rgba(0.0, 0.0, 0.0, 1.0)
246 ctx.move_to(xh * size - hs / 2, yh * size - hs / 2)
247 ctx.rel_line_to(hs, 0)
248 ctx.rel_line_to(0, hs)
249 ctx.rel_line_to(-hs, 0)
254 pixel_map[None, :, :, :3]
258 .mul(Box.nb_rgb_levels)
263 def random_scene(nb_insert_attempts=3):
266 ((Box.nb_rgb_levels - 1), 0, 0),
267 (0, (Box.nb_rgb_levels - 1), 0),
268 (0, 0, (Box.nb_rgb_levels - 1)),
269 ((Box.nb_rgb_levels - 1), (Box.nb_rgb_levels - 1), 0),
271 (Box.nb_rgb_levels * 2) // 3,
272 (Box.nb_rgb_levels * 2) // 3,
273 (Box.nb_rgb_levels * 2) // 3,
277 for k in range(nb_insert_attempts):
278 wh = torch.rand(2) * 0.2 + 0.2
279 xy = torch.rand(2) * (1 - wh)
280 c = colors[torch.randint(len(colors), (1,))]
282 xy[0].item(), xy[1].item(), wh[0].item(), wh[1].item(), c[0], c[1], c[2]
284 if not b.collision(scene):
290 def generate_episode(steps, size=64):
307 scene = random_scene()
308 xh, yh = tuple(x.item() for x in torch.rand(2))
310 actions = torch.randint(len(effects), (len(steps),))
313 for s, a in zip(steps, actions):
315 frames.append(scene2tensor(xh, yh, scene, size=size))
317 grasp, dx, dy = effects[a]
321 if b.x <= xh and b.x + b.w >= xh and b.y <= yh and b.y + b.h >= yh:
330 or b.collision(scene)
341 if xh < 0 or xh > 1 or yh < 0 or yh > 1:
344 if nb_changes > len(steps) // 3:
347 return frames, actions
350 ######################################################################
353 def generate_episodes(nb, steps):
354 all_frames, all_actions = [], []
355 for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world-data"):
356 frames, actions = generate_episode(steps)
358 all_actions += [actions[None, :]]
359 return torch.cat(all_frames, 0).contiguous(), torch.cat(all_actions, 0)
362 def create_data_and_processors(
370 device=torch.device("cpu"),
371 device_storage=torch.device("cpu"),
374 assert mode in ["first_last"]
376 if mode == "first_last":
377 steps = [True] + [False] * (nb_steps + 1) + [True]
380 logger = lambda s: print(s)
382 train_input, train_actions = generate_episodes(nb_train_samples, steps)
383 train_input, train_actions = train_input.to(device_storage), train_actions.to(
386 test_input, test_actions = generate_episodes(nb_test_samples, steps)
387 test_input, test_actions = test_input.to(device_storage), test_actions.to(
391 encoder, quantizer, decoder = train_encoder(
395 nb_bits_per_token=nb_bits_per_token,
402 quantizer.train(False)
405 z = encoder(train_input[:1].to(device))
406 pow2 = (2 ** torch.arange(z.size(1), device=device))[None, None, :]
407 z_h, z_w = z.size(2), z.size(3)
409 logger(f"vqae input {train_input[0].size()} output {z[0].size()}")
411 def frame2seq(input, batch_size=25):
414 for x in input.split(batch_size):
417 ze_bool = (quantizer(z) >= 0).long()
419 ze_bool.permute(0, 2, 3, 1).reshape(
420 ze_bool.size(0), -1, ze_bool.size(1)
427 return torch.cat(seq, dim=0)
429 def seq2frame(input, batch_size=25, T=1e-2):
432 for seq in input.split(batch_size):
434 zd_bool = (seq[:, :, None] // p) % 2
435 zd_bool = zd_bool.reshape(zd_bool.size(0), z_h, z_w, -1).permute(0, 3, 1, 2)
436 logits = decoder(zd_bool * 2.0 - 1.0)
437 logits = logits.reshape(
438 logits.size(0), -1, 3, logits.size(2), logits.size(3)
439 ).permute(0, 2, 3, 4, 1)
440 output = torch.distributions.categorical.Categorical(
444 frames.append(output)
446 return torch.cat(frames, dim=0)
448 return train_input, train_actions, test_input, test_actions, frame2seq, seq2frame
451 ######################################################################
453 if __name__ == "__main__":
461 ) = create_data_and_processors(
469 input = test_input[:256]
471 seq = frame2seq(input)
472 output = seq2frame(seq)
474 torchvision.utils.save_image(
475 input.float() / (Box.nb_rgb_levels - 1), "orig.png", nrow=16
478 torchvision.utils.save_image(
479 output.float() / (Box.nb_rgb_levels - 1), "qtiz.png", nrow=16