[picoclvr.git] / tasks.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, os, tqdm
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 from mygpt import BracketedSequence
16
17 try:
18     from graph import save_attention_image
19 except ImportError:
20     save_attention_image = None
21
22 ######################################################################
23
24
25 def masked_inplace_autoregression(
26     model,
27     batch_size,
28     input,
29     ar_mask,
30     deterministic_synthesis,
31     forbidden_tokens=None,
32     progress_bar_desc="autoregression",
33     device=torch.device("cpu"),
34 ):
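    # Regenerates, batch by batch and in place, the entries of input
    # where ar_mask is 1; the entries where it is 0 are left untouched
    # and act as the conditioning prompt.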
35     assert input.size() == ar_mask.size()
36
37     batches = zip(input.split(batch_size), ar_mask.split(batch_size))
38
39     if progress_bar_desc is not None:
40         batches = tqdm.tqdm(
41             batches,
42             dynamic_ncols=True,
43             desc=progress_bar_desc,
44             total=(input.size(0) + batch_size - 1) // batch_size,
45         )
46
47     with torch.autograd.no_grad():
48         t = model.training
49         model.eval()
50
51         for input, ar_mask in batches:
52             model.masked_inplace_autoregression(
53                 input, ar_mask, forbidden_tokens, deterministic_synthesis
54             )
55
56         model.train(t)
57
58
59 ######################################################################
60
61
62 class Task:
63     def batches(self, split="train"):
64         pass
65
66     def vocabulary_size(self):
67         pass
68
69     def produce_results(
70         self, n_epoch, model, result_dir, logger, deterministic_synthesis
71     ):
72         pass
73
74
75 ####################
76
77 import problems
78
79
80 class SandBox(Task):
81     def __init__(
82         self,
83         problem,
84         nb_train_samples,
85         nb_test_samples,
86         batch_size,
87         logger=None,
88         device=torch.device("cpu"),
89         max_nb_codes=1024,
90     ):
91         super().__init__()
92
93         self.batch_size = batch_size
94         self.device = device
95         self.problem = problem
96
97         self.train_input, self.train_ar_mask = self.problem.generate_sequences(
98             nb_train_samples
99         )
100         self.test_input, self.test_ar_mask = self.problem.generate_sequences(
101             nb_test_samples
102         )
103
104         self.train_input, self.train_ar_mask = self.train_input.to(
105             device
106         ), self.train_ar_mask.to(device)
107         self.test_input, self.test_ar_mask = self.test_input.to(
108             device
109         ), self.test_ar_mask.to(device)
110
111         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
112
113         # A bit of paranoia never hurts
114         assert (
115             self.nb_codes <= max_nb_codes
116             and self.train_input.min() >= 0
117             and self.test_input.min() >= 0
118             and tuple(self.train_ar_mask.unique()) == (0, 1)
119             and tuple(self.test_ar_mask.unique()) == (0, 1)
120         )
121
122     def batches(self, split="train", nb_to_use=-1, desc=None):
123         assert split in {"train", "test"}
124         input = self.train_input if split == "train" else self.test_input
125         if nb_to_use > 0:
126             input = input[:nb_to_use]
127         if desc is None:
128             desc = f"epoch-{split}"
129         for batch in tqdm.tqdm(
130             input.split(self.batch_size), dynamic_ncols=True, desc=desc
131         ):
132             yield batch
133
134     def vocabulary_size(self):
135         return self.nb_codes
136
137     def produce_results(
138         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
139     ):
140         def compute_accuracy(input, ar_mask, logger=None):
141             input, ar_mask = input[:nmax], ar_mask[:nmax]
142             result = input.clone() * (1 - ar_mask)
143
144             masked_inplace_autoregression(
145                 model,
146                 self.batch_size,
147                 result,
148                 ar_mask,
149                 deterministic_synthesis,
150                 progress_bar_desc=None,
151                 device=self.device,
152             )
153
154             if logger is not None:
155                 for sp, st in zip(result[:10], input[:10]):
156                     logger(
157                         f"test_sequences {n_epoch} prediction   {self.problem.seq2str(sp)}"
158                     )
159                     logger(
160                         f"               {n_epoch} ground truth {self.problem.seq2str(st)}"
161                     )
162
163             nb_total = ar_mask.sum().item()
164             nb_correct = ((result == input).long() * ar_mask).sum().item()
165
166             return nb_total, nb_correct
167
168         train_nb_total, train_nb_correct = compute_accuracy(
169             self.train_input, self.train_ar_mask
170         )
171
172         logger(
173             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
174         )
175
176         test_nb_total, test_nb_correct = compute_accuracy(
177             self.test_input, self.test_ar_mask, logger
178         )
179
180         logger(
181             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
182         )
183
184         if save_attention_image is not None:
185             for k in range(10):
186                 ns = torch.randint(self.test_input.size(0), (1,)).item()
187                 input = self.test_input[ns : ns + 1].clone()
188
189                 with torch.autograd.no_grad():
190                     t = model.training
191                     model.eval()
192                     model.record_attention(True)
193                     model(BracketedSequence(input))
194                     model.train(t)
195                     ram = model.retrieve_attention()
196                     model.record_attention(False)
197
198                 tokens_output = [c for c in self.problem.seq2str(input[0])]
199                 tokens_input = ["n/a"] + tokens_output[:-1]
200                 for n_head in range(ram[0].size(1)):
201                     filename = os.path.join(
202                         result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
203                     )
204                     attention_matrices = [m[0, n_head] for m in ram]
205                     save_attention_image(
206                         filename,
207                         tokens_input,
208                         tokens_output,
209                         attention_matrices,
210                         k_top=10,
211                         # min_total_attention=0.9,
212                         token_gap=12,
213                         layer_gap=50,
214                     )
215                     logger(f"wrote {filename}")
216
217
218 ######################################################################
219
220 import picoclvr
221
222
223 class PicoCLVR(Task):
224     # Make a tensor from a list of strings
225     def tensorize(self, descr):
226         token_descr = [s.strip().split(" ") for s in descr]
227         l = max([len(s) for s in token_descr])
228         token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
229         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
230         return torch.tensor(id_descr, device=self.device)
231
232     # Make a list of strings from a tensor
233     def detensorize(self, x):
234         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
235
236     # Trim z to remove as many padding tokens as possible on the left
237     # and the right. If z is a tuple, all its elements are trimmed
238     # according to the trimming computed for its first element.
239     def trim(self, z, token="<nul>"):
240         n = self.token2id[token]
241         if type(z) == tuple:
242             x = z[0]
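            # i is 0 before the first column holding a non-padding token
            # and reaches its maximum at the last such column, so the
            # slice [a:b] keeps exactly the non-padding span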
243             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
244             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
245             return tuple([t[:, a:b] for t in z])
246         else:
247             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
248             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
249             return z[:, a:b]
250
251     ######################
252
253     def __init__(
254         self,
255         nb_train_samples,
256         nb_test_samples,
257         batch_size,
258         height,
259         width,
260         nb_colors=5,
261         logger=None,
262         device=torch.device("cpu"),
263         pruner_train=None,
264         pruner_eval=None,
265     ):
266         super().__init__()
267
268         def generate_descr(nb, cache_suffix, pruner):
269             return picoclvr.generate(
270                 nb,
271                 height=self.height,
272                 width=self.width,
273                 nb_colors=nb_colors,
274                 pruner=pruner,
275             )
276
277         self.height = height
278         self.width = width
279         self.batch_size = batch_size
280         self.device = device
281         self.pruner_train = pruner_train
282         self.pruner_eval = pruner_eval
283
284         if logger is not None:
285             logger(
286                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
287             )
288
289         self.train_descr = generate_descr(
290             nb_train_samples, "train", pruner=self.pruner_train
291         )
292         self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)
293
294         # Build the tokenizer
295         tokens = {"<nul>", "<img>"}
296         for d in [self.train_descr, self.test_descr]:
297             for s in d:
298                 for t in s.strip().split(" "):
299                     tokens.add(t)
300         # Turn the set into a sorted list so that the same descriptions
301         # always yield the same tensors
302         tokens = list(tokens)
303         tokens.sort()
304         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
305         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
306         self.t_img, self.t_nul = self.token2id["<img>"], self.token2id["<nul>"]
307
308         # Tokenize the train and test sets
309         self.train_input = self.tensorize(self.train_descr)
310         self.test_input = self.tensorize(self.test_descr)
311
312     def batches(self, split="train"):
313         assert split in {"train", "test"}
314         input = self.train_input if split == "train" else self.test_input
315         for batch in tqdm.tqdm(
316             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
317         ):
318             yield self.trim(batch)
319
320     def vocabulary_size(self):
321         return len(self.token2id)
322
323     def compute_missing_properties(
324         self, n_epoch, model, logger, deterministic_synthesis, pruner=None
325     ):
326         acc_nb_requested_properties = []
327         acc_nb_missing_properties = []
328         acc_nb_results = 0
329
330         for input in tqdm.tqdm(
331             self.test_input.split(self.batch_size),
332             dynamic_ncols=True,
333             desc=f"test-properties",
334         ):
335             result = input.clone()
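            # Mask everything from the <img> token (included) onward: those
            # positions are blanked to <nul> and regenerated by the model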
336             ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1)
337             result = (1 - ar_mask) * result + ar_mask * self.t_nul
338             masked_inplace_autoregression(
339                 model,
340                 self.batch_size,
341                 result,
342                 ar_mask,
343                 deterministic_synthesis,
344                 progress_bar_desc=None,
345                 device=self.device,
346             )
347
348             result_descr = self.detensorize(result)
349             np = picoclvr.nb_properties(
350                 result_descr,
351                 height=self.height,
352                 width=self.width,
353                 pruner=pruner,
354             )
355             nb_requested_properties, _, nb_missing_properties = zip(*np)
356             acc_nb_requested_properties += nb_requested_properties
357             acc_nb_missing_properties += nb_missing_properties
358             acc_nb_results += len(result_descr)
359
360         nb_requested_properties = sum(acc_nb_requested_properties)
361         nb_missing_properties = sum(acc_nb_missing_properties)
362
363         prefix = "" if pruner is None else "pruned_"
364         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
365         logger(
366             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
367         )
368         logger(
369             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
370         )
371
372     ######################################################################
373
374     def produce_results(
375         self, n_epoch, model, result_dir, logger, deterministic_synthesis
376     ):
377         self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis)
378
379         if self.pruner_eval is not None:
380             self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis, self.pruner_eval)
381
382         nb_tokens_to_generate = self.height * self.width + 3
383         result_descr = []
384         nb_per_primer = 8
385         primer = []
386
387         for primer_descr in [
388             "red above green <sep> green top <sep> blue right of red",
389             "there is red <sep> there is yellow <sep> there is blue",
390             "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
391             "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
392         ]:
393             primer += [primer_descr + " <img>"] * nb_per_primer
394
395         result = self.tensorize(primer)
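        # Append height*width+1 <nul> tokens to each primer; every <nul>
        # position (the appended slots and any padding) is then regenerated
        # as image tokens by the model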
396         fill = result.new_full(
397             result.size()[:-1] + (self.height * self.width + 1,), self.t_nul
398         )
399         result = torch.cat((result, fill), 1)
400         ar_mask = (result == self.t_nul).long()
401         masked_inplace_autoregression(
402             model,
403             self.batch_size,
404             result,
405             ar_mask,
406             deterministic_synthesis,
407             device=self.device,
408         )
409         result_descr = self.detensorize(result)
410
411         np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
412
413         acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
414         acc_nb_results = len(result_descr)
415
416         nb_requested_properties = sum(acc_nb_requested_properties)
417         nb_missing_properties = sum(acc_nb_missing_properties)
418
419         prefix = "demo_"
420         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
421         logger(
422             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
423         )
424         logger(
425             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
426         )
427
428         img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)
429
430         if img.dim() == 5:
431             if img.size(1) == 1:
432                 img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
433             else:
434                 img = torch.cat(
435                     [
436                         torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
437                         for x in img
438                     ],
439                     0,
440                 )
441
442         image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png")
443         torchvision.utils.save_image(
444             img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
445         )
446         logger(f"wrote {image_name}")
447
448
449 ######################################################################
450
451
452 class MNIST(Task):
453     def __init__(
454         self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
455     ):
456         super().__init__()
457
458         self.nb_train_samples = nb_train_samples
459         self.nb_test_samples = nb_test_samples
460         self.batch_size = batch_size
461         self.device = device
462         data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True)
463         self.train_input = data_set.data[:nb_train_samples].view(-1, 28 * 28).long()
464         data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
465         self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
466
467     def batches(self, split="train", nb_to_use=-1, desc=None):
468         assert split in {"train", "test"}
469         input = self.train_input if split == "train" else self.test_input
470         if nb_to_use > 0:
471             input = input[:nb_to_use]
472         if desc is None:
473             desc = f"epoch-{split}"
474         for batch in tqdm.tqdm(
475             input.split(self.batch_size), dynamic_ncols=True, desc=desc
476         ):
477             yield batch
478
479     def vocabulary_size(self):
480         return 256
481
482     def produce_results(
483         self, n_epoch, model, result_dir, logger, deterministic_synthesis
484     ):
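        # Generate 64 images from scratch: all 28*28 pixel tokens are
        # masked and sampled by the model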
485         results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
486         ar_mask = torch.full_like(results, 1)
487         masked_inplace_autoregression(
488             model,
489             self.batch_size,
490             results,
491             ar_mask,
492             deterministic_synthesis,
493             device=self.device,
494         )
495         image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
496         torchvision.utils.save_image(
497             1 - results.reshape(-1, 1, 28, 28) / 255.0,
498             image_name,
499             nrow=16,
500             pad_value=0.8,
501         )
502         logger(f"wrote {image_name}")
503
504
505 ######################################################################
506
507 import maze
508
509
510 class Maze(Task):
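    # A sample is the concatenation of the flattened maze map and the
    # flattened path map; map2seq and seq2map convert between the two
    # representations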
511     def map2seq(self, *m):
512         return torch.cat([x.flatten(1) for x in m], 1)
513
514     def seq2map(self, s):
515         s = s.reshape(s.size(0), -1, self.height, self.width)
516         return (s[:, k] for k in range(s.size(1)))
517
518     def __init__(
519         self,
520         nb_train_samples,
521         nb_test_samples,
522         batch_size,
523         height,
524         width,
525         nb_walls,
526         device=torch.device("cpu"),
527     ):
528         super().__init__()
529
530         self.batch_size = batch_size
531         self.height = height
532         self.width = width
533         self.device = device
534
535         train_mazes, train_paths, _ = maze.create_maze_data(
536             nb_train_samples,
537             height=height,
538             width=width,
539             nb_walls=nb_walls,
540             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
541         )
542         self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
543
544         test_mazes, test_paths, _ = maze.create_maze_data(
545             nb_test_samples,
546             height=height,
547             width=width,
548             nb_walls=nb_walls,
549             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
550         )
551         self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
552
553         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
554
555     def batches(self, split="train", nb_to_use=-1, desc=None):
556         assert split in {"train", "test"}
557         input = self.train_input if split == "train" else self.test_input
558         if nb_to_use > 0:
559             input = input[:nb_to_use]
560         if desc is None:
561             desc = f"epoch-{split}"
562         for batch in tqdm.tqdm(
563             input.split(self.batch_size), dynamic_ncols=True, desc=desc
564         ):
565             yield batch
566
567     def vocabulary_size(self):
568         return self.nb_codes
569
570     def compute_error(
571         self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
572     ):
573         nb_total, nb_correct = 0, 0
574         count = torch.zeros(
575             self.width * self.height,
576             self.width * self.height,
577             device=self.device,
578             dtype=torch.int64,
579         )
580
581         for input in self.batches(split, nb_to_use):
582             result = input.clone()
583             ar_mask = result.new_zeros(result.size())
584             ar_mask[:, self.height * self.width :] = 1
585             result *= 1 - ar_mask
586             masked_inplace_autoregression(
587                 model,
588                 self.batch_size,
589                 result,
590                 ar_mask,
591                 deterministic_synthesis,
592                 progress_bar_desc=None,
593                 device=self.device,
594             )
595             mazes, paths = self.seq2map(result)
596             path_correctness = maze.path_correctness(mazes, paths)
597             nb_correct += path_correctness.long().sum()
598             nb_total += mazes.size(0)
599
600             optimal_path_lengths = (
601                 (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
602             )
603             predicted_path_lengths = (
604                 (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
605             )
606             optimal_path_lengths = optimal_path_lengths[path_correctness]
607             predicted_path_lengths = predicted_path_lengths[path_correctness]
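            # 2D histogram of (optimal path length, predicted path length)
            # over the correctly solved mazes; accumulate=True handles
            # duplicate (i, j) pairs within a batch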
608             count.index_put_((optimal_path_lengths, predicted_path_lengths), torch.ones_like(optimal_path_lengths), accumulate=True)
609
610         if count.max() == 0:
611             count = None
612         else:
613             count = count[
614                 : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
615             ]
616
617         return nb_total, nb_correct, count
618
619     def produce_results(
620         self, n_epoch, model, result_dir, logger, deterministic_synthesis
621     ):
622         train_nb_total, train_nb_correct, count = self.compute_error(
623             model,
624             "train",
625             nb_to_use=1000,
626             deterministic_synthesis=deterministic_synthesis,
627         )
628         logger(
629             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
630         )
631
632         test_nb_total, test_nb_correct, count = self.compute_error(
633             model,
634             "test",
635             nb_to_use=1000,
636             deterministic_synthesis=deterministic_synthesis,
637         )
638         logger(
639             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
640         )
641
642         if count is not None:
643             proportion_optimal = count.diagonal().sum().float() / count.sum()
644             logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
645             with open(
646                 os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
647             ) as f:
648                 for i in range(count.size(0)):
649                     for j in range(count.size(1)):
650                         eol = " " if j < count.size(1) - 1 else "\n"
651                         f.write(f"{count[i,j]}{eol}")
652
653         input = self.test_input[:48]
654         result = input.clone()
655         ar_mask = result.new_zeros(result.size())
656         ar_mask[:, self.height * self.width :] = 1
657         result *= 1 - ar_mask
658         masked_inplace_autoregression(
659             model,
660             self.batch_size,
661             result,
662             ar_mask,
663             deterministic_synthesis,
664             device=self.device,
665         )
666
667         mazes, paths = self.seq2map(input)
668         _, predicted_paths = self.seq2map(result)
669
670         filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png")
671         maze.save_image(
672             filename,
673             mazes=mazes,
674             target_paths=paths,
675             predicted_paths=predicted_paths,
676             path_correct=maze.path_correctness(mazes, predicted_paths),
677             path_optimal=maze.path_optimality(paths, predicted_paths),
678         )
679         logger(f"wrote {filename}")
680
681
682 ######################################################################
683
684
685 import snake
686
687
688 class Snake(Task):
689     def __init__(
690         self,
691         nb_train_samples,
692         nb_test_samples,
693         batch_size,
694         height,
695         width,
696         nb_colors,
697         length,
698         prompt_length,
699         device=torch.device("cpu"),
700     ):
701         super().__init__()
702
703         self.batch_size = batch_size
704         self.height = height
705         self.width = width
706         self.device = device
707         self.prompt_length = prompt_length
708
709         self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
710             nb_train_samples,
711             height,
712             width,
713             nb_colors,
714             length,
715             prompt_length,
716             self.device,
717         )
718         self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
719             nb_test_samples,
720             height,
721             width,
722             nb_colors,
723             length,
724             prompt_length,
725             self.device,
726         )
727
728         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
729
730     def batches(self, split="train", nb_to_use=-1, desc=None):
731         assert split in {"train", "test"}
732         input = self.train_input if split == "train" else self.test_input
733         if nb_to_use > 0:
734             input = input[:nb_to_use]
735         if desc is None:
736             desc = f"epoch-{split}"
737         for batch in tqdm.tqdm(
738             input.split(self.batch_size), dynamic_ncols=True, desc=desc
739         ):
740             yield batch
741
742     def vocabulary_size(self):
743         return self.nb_codes
744
745     def produce_results(
746         self, n_epoch, model, result_dir, logger, deterministic_synthesis
747     ):
748         def compute_nb_correct(input, prior_visits):
749             result = input.clone()
750             i = torch.arange(result.size(1), device=result.device)[None, :]
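            # The prompt occupies the first 2*prompt_length tokens; beyond
            # it, every second token is masked and must be predicted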
751             ar_mask = (
752                 torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
753                 .long()
754                 .expand_as(result)
755             )
756             result *= 1 - ar_mask
757
758             masked_inplace_autoregression(
759                 model,
760                 self.batch_size,
761                 result,
762                 ar_mask,
763                 deterministic_synthesis,
764                 device=self.device,
765             )
766
767             nb_total = ((prior_visits > 0) * ar_mask).sum()
768
769             nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum()
770
771             return nb_total, nb_correct
772
773         test_nb_total, test_nb_correct = compute_nb_correct(
774             self.test_input[:1000], self.test_prior_visits[:1000]
775         )
776
777         logger(
778             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
779         )
780
781
782 ######################################################################
783
784
785 import stack
786
787
788 class Stack(Task):
789     def __init__(
790         self,
791         nb_train_samples,
792         nb_test_samples,
793         batch_size,
794         logger,
795         nb_steps,
796         nb_stacks,
797         nb_digits,
798         fraction_values_for_train=None,
799         device=torch.device("cpu"),
800     ):
801         super().__init__()
802
803         self.batch_size = batch_size
804         self.nb_steps = nb_steps
805         self.nb_stacks = nb_stacks
806         self.nb_digits = nb_digits
807         self.device = device
808
809         if fraction_values_for_train is None:
810             values_for_train = None
811             values_for_test = None
812         else:
813             all = torch.randperm(10**nb_digits)
814             nb_for_train = int(all.size(0) * fraction_values_for_train)
815             values_for_train = all[:nb_for_train]
816             values_for_test = all[nb_for_train:]
817
818         self.train_input, self.train_stack_counts = stack.generate_sequences(
819             nb_train_samples,
820             nb_steps,
821             nb_stacks,
822             nb_digits,
823             values_for_train,
824             self.device,
825         )
826
827         self.test_input, self.test_stack_counts = stack.generate_sequences(
828             nb_test_samples,
829             nb_steps,
830             nb_stacks,
831             nb_digits,
832             values_for_test,
833             self.device,
834         )
835
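        # Positions holding an odd operation code below 2*nb_stacks
        # (presumably the pop operations); their stack counts give the
        # distribution logged below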
836         i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
837         counts = self.test_stack_counts.flatten()[i.flatten()]
838         counts = F.one_hot(counts).sum(0)
839         logger(f"test_pop_stack_counts {counts}")
840
841         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
842
843     def batches(self, split="train", nb_to_use=-1, desc=None):
844         assert split in {"train", "test"}
845         input = self.train_input if split == "train" else self.test_input
846         if nb_to_use > 0:
847             input = input[:nb_to_use]
848         if desc is None:
849             desc = f"epoch-{split}"
850         for batch in tqdm.tqdm(
851             input.split(self.batch_size), dynamic_ncols=True, desc=desc
852         ):
853             yield batch
854
855     def vocabulary_size(self):
856         return self.nb_codes
857
858     def produce_results(
859         self, n_epoch, model, result_dir, logger, deterministic_synthesis
860     ):
861         def compute_nb_correct(input):
862             result = input.clone()
863             stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
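            # After removing the popped values, the positions that differ
            # from the input are exactly the ones to regenerate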
864             ar_mask = (result != input).long()
865             masked_inplace_autoregression(
866                 model,
867                 self.batch_size,
868                 result,
869                 ar_mask,
870                 deterministic_synthesis,
871                 device=self.device,
872             )
873
874             errors = ((result != input).long() * ar_mask).reshape(
875                 -1, 1 + self.nb_digits
876             )
877             ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)
878
879             nb_total = ar_mask.max(1).values.sum()
880             nb_correct = nb_total - errors.max(1).values.sum()
881
882             return nb_total, nb_correct
883
884         test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
885
886         logger(
887             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
888         )
889
890         ##############################################################
891         # Log a few generated sequences
892         input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
893         result = input.clone()
894         stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
895         ar_mask = (result != input).long()
896
897         # for n in range(result.size(0)):
898         # logger(
899         # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
900         # )
901
902         masked_inplace_autoregression(
903             model,
904             self.batch_size,
905             result,
906             ar_mask,
907             deterministic_synthesis,
908             device=self.device,
909         )
910
911         for n in range(result.size(0)):
912             logger(
913                 f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
914             )
915         ##############################################################
916
917
918 ######################################################################
919
920 import rpl
921
922
923 class RPL(Task):
924     def tensorize(self, sequences):
925         len_max = max([len(x) for x in sequences])
926         return torch.cat(
927             [
928                 torch.tensor(
929                     [
930                         [
931                             self.token2id[str(c)]
932                             for c in s + ["<nul>"] * (len_max - len(s))
933                         ]
934                         for s in sequences
935                     ]
936                 )
937             ],
938             0,
939         )
940
941     def seq2str(self, seq):
942         return " ".join([self.id2token[i] for i in seq])
943
944     def __init__(
945         self,
946         nb_train_samples,
947         nb_test_samples,
948         batch_size,
949         nb_starting_values=3,
950         max_input=9,
951         prog_len=6,
952         nb_runs=5,
953         no_prog=False,
954         logger=None,
955         device=torch.device("cpu"),
956     ):
957         super().__init__()
958
959         self.batch_size = batch_size
960         self.device = device
961         self.no_prog = no_prog
962
963         train_sequences = [
964             rpl.generate(
965                 nb_starting_values=nb_starting_values,
966                 nb_result_values_max=4 * nb_starting_values,
967                 max_input=max_input,
968                 prog_len=prog_len,
969                 nb_runs=nb_runs,
970             )
971             for _ in tqdm.tqdm(range(nb_train_samples), desc="train-data")
972         ]
973
974         test_sequences = [
975             rpl.generate(
976                 nb_starting_values=nb_starting_values,
977                 nb_result_values_max=4 * nb_starting_values,
978                 max_input=max_input,
979                 prog_len=prog_len,
980                 nb_runs=nb_runs,
981             )
982             for _ in tqdm.tqdm(range(nb_test_samples), desc="test-data")
983         ]
984
985         symbols = list(
986             set(["<nul>"] + [x for l in train_sequences + test_sequences for x in l])
987         )
988         val_max = max([x if type(x) is int else 0 for x in symbols])
989         symbols = list(filter(lambda x: type(x) is str, symbols))
990         symbols.sort()
991         symbols += [str(n) for n in range(val_max + 1)]
992         self.token2id = dict([(c, n) for n, c in enumerate(symbols)])
993         self.id2token = dict([(n, c) for c, n in self.token2id.items()])
994
995         self.t_nul = self.token2id["<nul>"]
996         self.t_input = self.token2id["<in>"]
997         self.t_output = self.token2id["<out>"]
998         self.t_prog = self.token2id["<prg>"]
999         self.t_end = self.token2id["<end>"]
1000
1001         self.train_input = self.tensorize(train_sequences)
1002         self.test_input = self.tensorize(test_sequences)
1003
1004         if no_prog:
1005             # Excise the program from every train and test example
1006             k = torch.arange(self.train_input.size(1), device=self.train_input.device)[
1007                 None, :
1008             ]
1009             p = (
1010                 ((self.train_input == self.t_prog).long() * k)
1011                 .max(1, keepdim=True)
1012                 .values
1013             )
1014             self.train_input = (
1015                 self.train_input * (k <= p).long()
1016                 + self.t_end * (k == p + 1).long()
1017                 + self.t_nul * (k > p + 1).long()
1018             )
1019             k = torch.arange(self.test_input.size(1), device=self.test_input.device)[
1020                 None, :
1021             ]
1022             p = (
1023                 ((self.test_input == self.t_prog).long() * k)
1024                 .max(1, keepdim=True)
1025                 .values
1026             )
1027             self.test_input = (
1028                 self.test_input * (k <= p).long()
1029                 + self.t_end * (k == p + 1).long()
1030                 + self.t_nul * (k > p + 1).long()
1031             )
1032
1033         if logger is not None:
1034             logger(f"value_max {val_max}")
1035             for x in self.train_input[:25]:
1036                 end = (x != self.t_nul).nonzero().max().item() + 1
1037                 seq = [self.id2token[i.item()] for i in x[:end]]
1038                 s = " ".join(seq)
1039                 logger(f"example_seq {s}")
1040
1041         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1042
1043     def batches(self, split="train", nb_to_use=-1, desc=None):
1044         assert split in {"train", "test"}
1045         input = self.train_input if split == "train" else self.test_input
1046         if nb_to_use > 0:
1047             input = input[:nb_to_use]
1048         if desc is None:
1049             desc = f"epoch-{split}"
1050         for batch in tqdm.tqdm(
1051             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1052         ):
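            # Trim the trailing <nul> padding columns (keeping a small
            # margin) before moving the batch to the device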
1053             last = (batch != self.t_nul).max(0).values.nonzero().max() + 3
1054             batch = batch[:, :last].to(self.device)
1055             yield batch
1056
1057     def vocabulary_size(self):
1058         return self.nb_codes
1059
1060     def produce_results(
1061         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1062     ):
1063         # --------------------------------------------------------------------
1064         def compute_nb_errors_prog(input, nb_to_log=0):
1065             result = input.clone()
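            # Mask everything strictly after the <prg> token: the program
            # is blanked to <nul> and must be regenerated by the model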
1066             s = (result == self.t_prog).long()
1067             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1068             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1069
1070             masked_inplace_autoregression(
1071                 model,
1072                 self.batch_size,
1073                 result,
1074                 ar_mask,
1075                 deterministic_synthesis,
1076                 device=self.device,
1077             )
1078
1079             sum_nb_total, sum_nb_errors = 0, 0
1080             for one_input, one_result in zip(input, result):
1081                 seq = [self.id2token[i.item()] for i in one_result]
1082                 nb_total, nb_errors, prog, stacks = rpl.compute_nb_errors(seq)
1083                 sum_nb_total += 1
1084                 sum_nb_errors += 0 if nb_errors == 0 else 1
1085                 if nb_to_log > 0:
1086                     gt_seq = [self.id2token[i.item()] for i in one_input]
1087                     _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq)
1088                     gt_prog = " ".join([str(x) for x in gt_prog])
1089                     prog = " ".join([str(x) for x in prog])
1090                     comment = "*" if nb_errors == 0 else "-"
1091                     logger(f"{comment} PROG [{gt_prog}] PREDICTED [{prog}]")
1092                     for start_stack, target_stack, result_stack, correct in stacks:
1093                         comment = "*" if correct else "-"
1094                         start_stack = " ".join([str(x) for x in start_stack])
1095                         target_stack = " ".join([str(x) for x in target_stack])
1096                         result_stack = " ".join([str(x) for x in result_stack])
1097                         logger(
1098                             f"  {comment} [{start_stack}] -> [{target_stack}] PREDICTED [{result_stack}]"
1099                         )
1100                     nb_to_log -= 1
1101
1102             return sum_nb_total, sum_nb_errors
1103
1104         # --------------------------------------------------------------------
1105         def compute_nb_errors_output(input, nb_to_log=0):
1106             result = input.clone()
1107             k = torch.arange(result.size(1), device=result.device)[None, :]
1108             last_output_idx = (
1109                 ((result == self.t_output) * k).max(dim=1, keepdim=True).values
1110             )
1111             first_prog_idx = (
1112                 ((result == self.t_prog) * k).max(dim=1, keepdim=True).values
1113             )
1114             ar_mask = (k > last_output_idx).long() * (k < first_prog_idx).long()
1115             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1116
1117             masked_inplace_autoregression(
1118                 model,
1119                 self.batch_size,
1120                 result,
1121                 ar_mask,
1122                 deterministic_synthesis,
1123                 device=self.device,
1124             )
1125
1126             sum_nb_total, sum_nb_errors = 0, 0
1127             for one_input, one_result, i, j in zip(
1128                 input, result, last_output_idx, first_prog_idx
1129             ):
1130                 seq = [self.id2token[i.item()] for i in one_result]
1131                 sum_nb_total += 1
1132                 correct = (one_input - one_result).abs().max() == 0
1133                 sum_nb_errors += 0 if correct else 1
1134                 if nb_to_log > 0:
1135                     result_stack = [
1136                         self.id2token[i.item()] for i in one_result[i : j + 1]
1137                     ]
1138                     target_stack = [
1139                         self.id2token[i.item()] for i in one_input[i : j + 1]
1140                     ]
1141                     comment = "*" if correct else "-"
1142                     result_stack = " ".join([str(x) for x in result_stack])
1143                     target_stack = " ".join([str(x) for x in target_stack])
1144                     logger(
1145                         f"output_test {comment} [{target_stack}] PREDICTED [{result_stack}]"
1146                     )
1147                     nb_to_log -= 1
1148
1149             return sum_nb_total, sum_nb_errors
1150
1151         # --------------------------------------------------------------------
1152
1153         if not self.no_prog:
1154             test_nb_total, test_nb_errors = compute_nb_errors_prog(
1155                 self.test_input[:1000].to(self.device), nb_to_log=10
1156             )
1157
1158             logger(
1159                 f"accuracy_prog_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1160             )
1161
1162         test_nb_total, test_nb_errors = compute_nb_errors_output(
1163             self.test_input[:1000].to(self.device), nb_to_log=10
1164         )
1165
1166         logger(
1167             f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1168         )
1169
1170         if save_attention_image is not None:
1171             ns = torch.randint(self.test_input.size(0), (1,)).item()
1172             input = self.test_input[ns : ns + 1].clone()
1173             last = (input != self.t_nul).max(0).values.nonzero().max() + 3
1174             input = input[:, :last].to(self.device)
1175
1176             with torch.autograd.no_grad():
1177                 t = model.training
1178                 model.eval()
1179                 model.record_attention(True)
1180                 model(BracketedSequence(input))
1181                 model.train(t)
1182                 ram = model.retrieve_attention()
1183                 model.record_attention(False)
1184
1185             tokens_output = [self.id2token[i.item()] for i in input[0]]
1186             tokens_input = ["n/a"] + tokens_output[:-1]
1187             for n_head in range(ram[0].size(1)):
1188                 filename = os.path.join(
1189                     result_dir, f"rpl_attention_{n_epoch}_h{n_head}.pdf"
1190                 )
1191                 attention_matrices = [m[0, n_head] for m in ram]
1192                 save_attention_image(
1193                     filename,
1194                     tokens_input,
1195                     tokens_output,
1196                     attention_matrices,
1197                     k_top=10,
1198                     # min_total_attention=0.9,
1199                     token_gap=12,
1200                     layer_gap=50,
1201                 )
1202                 logger(f"wrote {filename}")
1203
1204
1205 ######################################################################
1206
1207
1208 import expr
1209
1210
1211 class Expr(Task):
1212     def tensorize(self, sequences):
1213         len_max = max([len(x) for x in sequences])
1214         return torch.cat(
1215             [
1216                 torch.tensor(
1217                     [
1218                         [self.char2id[c] for c in s + "#" * (len_max - len(s))]
1219                         for s in sequences
1220                     ]
1221                 )
1222             ],
1223             0,
1224         ).to(self.device)
1225
1226     def __init__(
1227         self,
1228         nb_train_samples,
1229         nb_test_samples,
1230         nb_variables,
1231         sequence_length,
1232         operand_max,
1233         result_max,
1234         batch_size,
1235         device=torch.device("cpu"),
1236     ):
1237         super().__init__()
1238
1239         self.batch_size = batch_size
1240         self.device = device
1241
1242         train_sequences = expr.generate_sequences(
1243             nb_train_samples,
1244             nb_variables=nb_variables,
1245             length=sequence_length,
1246             operand_max=operand_max,
1247             result_max=result_max,
1248         )
1249
1250         test_sequences = expr.generate_sequences(
1251             nb_test_samples,
1252             nb_variables=nb_variables,
1253             length=sequence_length,
1254             operand_max=operand_max,
1255             result_max=result_max,
1256         )
1257
1258         symbols = list(set("#" + "".join(train_sequences + test_sequences)))
1259         symbols.sort()
1260
1261         self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
1262         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
1263
1264         self.filler, self.space = self.char2id["#"], self.char2id[" "]
1265
1266         self.train_input = self.tensorize(train_sequences)
1267         self.test_input = self.tensorize(test_sequences)
1268
1269         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1270
1271     def batches(self, split="train", nb_to_use=-1, desc=None):
1272         assert split in {"train", "test"}
1273         input = self.train_input if split == "train" else self.test_input
1274         if nb_to_use > 0:
1275             input = input[:nb_to_use]
1276         if desc is None:
1277             desc = f"epoch-{split}"
1278         for batch in tqdm.tqdm(
1279             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1280         ):
1281             last = (batch != self.filler).max(0).values.nonzero().max() + 3
1282             batch = batch[:, :last]
1283             yield batch
1284
1285     def vocabulary_size(self):
1286         return self.nb_codes
1287
1288     def seq2str(self, s):
1289         return "".join([self.id2char[k.item()] for k in s])
1290
1291     def produce_results(
1292         self,
1293         n_epoch,
1294         model,
1295         result_dir,
1296         logger,
1297         deterministic_synthesis,
1298         input_file=None,
1299     ):
1300         def compute_nb_correct(input):
1301             result = input.clone()
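            # Mask everything strictly after the first space token: those
            # positions are blanked with the filler and regenerated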
1302             s = (result == self.space).long()
1303             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1304             result = (1 - ar_mask) * result + ar_mask * self.filler
1305             masked_inplace_autoregression(
1306                 model,
1307                 self.batch_size,
1308                 result,
1309                 ar_mask,
1310                 deterministic_synthesis,
1311                 device=self.device,
1312             )
1313
1314             nb_total = input.size(0)
1315             nb_correct = (input == result).long().min(1).values.sum()
1316
1317             #######################################################################
1318             # Compute predicted vs. true variable values
1319
1320             nb_delta = torch.zeros(5, dtype=torch.int64)
1321             nb_missed = 0
1322
1323             values_input = expr.extract_results([self.seq2str(s) for s in input])
1324             values_result = expr.extract_results([self.seq2str(s) for s in result])
1325
1326             filename = os.path.join(result_dir, f"expr_result_{n_epoch:04d}.txt")
1327
1328             with open(filename, "w") as f:
1329                 for i, r in zip(values_input, values_result):
1330                     for n, vi in i.items():
1331                         vr = r.get(n)
1332                         f.write(f"{vi} {-1 if vr is None else vr}\n")
1333
1334                         if vr is None or vr < 0:
1335                             nb_missed += 1
1336                         else:
1337                             d = abs(vr - vi)
1338                             if d >= nb_delta.size(0):
1339                                 nb_missed += 1
1340                             else:
1341                                 nb_delta[d] += 1
1342
1343             ######################################################################
1344
1345             return nb_total, nb_correct, nb_delta, nb_missed
1346
1347         (
1348             test_nb_total,
1349             test_nb_correct,
1350             test_nb_delta,
1351             test_nb_missed,
1352         ) = compute_nb_correct(self.test_input[:10000])
1353
1354         logger(
1355             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1356         )
1357
1358         nb_total = test_nb_delta.sum() + test_nb_missed
1359         for d in range(test_nb_delta.size(0)):
1360             logger(
1361                 f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%"
1362             )
1363         logger(
1364             f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%"
1365         )
1366
1367         ##############################################################
1368         # Log a few generated sequences
1369         if input_file is None:
1370             input = self.test_input[:10]
1371         else:
1372             with open(input_file, "r") as f:
1373                 sequences = [e.strip() for e in f.readlines()]
1374                 sequences = [s + " " + "#" * 50 for s in sequences]
1375                 input = self.tensorize(sequences)
1376
1377         result = input.clone()
1378         s = (result == self.space).long()
1379         ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1380         result = (1 - ar_mask) * result + ar_mask * self.filler
1381
1382         for n in range(result.size(0)):
1383             logger(f"test_before {self.seq2str(result[n])}")
1384
1385         masked_inplace_autoregression(
1386             model,
1387             self.batch_size,
1388             result,
1389             ar_mask,
1390             deterministic_synthesis,
1391             device=self.device,
1392         )
1393
1394         correct = (1 - ar_mask) * self.space + ar_mask * input
1395         for n in range(result.size(0)):
1396             comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
1397             logger(f"test_after  {self.seq2str(result[n])} {comment}")
1398             logger(f"truth       {self.seq2str(correct[n])}")
1399         ##############################################################
1400
1401
1402 ######################################################################
1403
1404 import world
1405
1406
1407 class World(Task):
1408     def __init__(
1409         self,
1410         nb_train_samples,
1411         nb_test_samples,
1412         batch_size,
1413         vqae_nb_epochs,
1414         logger=None,
1415         device=torch.device("cpu"),
1416         device_storage=torch.device("cpu"),
1417     ):
1418         super().__init__()
1419
1420         self.batch_size = batch_size
1421         self.device = device
1422
1423         (
1424             train_frames,
1425             train_action_seq,
1426             test_frames,
1427             test_action_seq,
1428             self.frame2seq,
1429             self.seq2frame,
1430         ) = world.create_data_and_processors(
1431             nb_train_samples,
1432             nb_test_samples,
1433             mode="first_last",
1434             nb_steps=30,
1435             nb_epochs=vqae_nb_epochs,
1436             logger=logger,
1437             device=device,
1438             device_storage=device_storage,
1439         )
1440
1441         train_frame_seq = self.frame2seq(train_frames).to(device_storage)
1442         test_frame_seq = self.frame2seq(test_frames).to(device_storage)
1443
1444         nb_frame_codes = max(train_frame_seq.max(), test_frame_seq.max()) + 1
1445         nb_action_codes = max(train_action_seq.max(), test_action_seq.max()) + 1
1446
1447         self.len_frame_seq = train_frame_seq.size(1)
1448         self.len_action_seq = train_action_seq.size(1)
1449         self.nb_codes = nb_frame_codes + nb_action_codes
1450
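        # Frames come in consecutive (first, last) pairs; group them so
        # that each sequence is first-frame tokens, then action tokens,
        # then last-frame tokens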
1451         train_frame_seq = train_frame_seq.reshape(train_frame_seq.size(0) // 2, 2, -1)
1452
1453         train_action_seq += nb_frame_codes
1454         self.train_input = torch.cat(
1455             (train_frame_seq[:, 0, :], train_action_seq, train_frame_seq[:, 1, :]), 1
1456         )
1457
1458         test_frame_seq = test_frame_seq.reshape(test_frame_seq.size(0) // 2, 2, -1)
1459         test_action_seq += nb_frame_codes
1460         self.test_input = torch.cat(
1461             (test_frame_seq[:, 0, :], test_action_seq, test_frame_seq[:, 1, :]), 1
1462         )
1463
1464     def batches(self, split="train", nb_to_use=-1, desc=None):
1465         assert split in {"train", "test"}
1466         input = self.train_input if split == "train" else self.test_input
1467         if nb_to_use > 0:
1468             input = input[:nb_to_use]
1469         if desc is None:
1470             desc = f"epoch-{split}"
1471         for batch in tqdm.tqdm(
1472             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1473         ):
1474             yield batch.to(self.device)
1475
1476     def vocabulary_size(self):
1477         return self.nb_codes
1478
1479     def produce_results(
1480         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1481     ):
1482         k = torch.arange(
1483             2 * self.len_frame_seq + self.len_action_seq, device=self.device
1484         )[None, :]
1485
1486         input = self.test_input[:64].to(self.device)
1487         result = input.clone()
1488
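        # Mask the final frame: the model must predict it from the first
        # frame and the action tokens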
1489         ar_mask = (
1490             (k >= self.len_frame_seq + self.len_action_seq).long().expand_as(result)
1491         )
1492         result *= 1 - ar_mask
1493
1494         masked_inplace_autoregression(
1495             model,
1496             self.batch_size,
1497             result,
1498             ar_mask,
1499             deterministic_synthesis,
1500             device=self.device,
1501         )
1502
1503         seq_start = input[:, : self.len_frame_seq]
1504         seq_end = input[:, self.len_frame_seq + self.len_action_seq :]
1505         seq_predicted = result[:, self.len_frame_seq + self.len_action_seq :]
1506
1507         result = torch.cat(
1508             (seq_start[:, None, :], seq_end[:, None, :], seq_predicted[:, None, :]), 1
1509         )
1510         result = result.reshape(-1, result.size(-1))
1511
1512         frames = self.seq2frame(result)
1513         image_name = os.path.join(result_dir, f"world_result_{n_epoch:04d}.png")
1514         torchvision.utils.save_image(
1515             frames.float() / (world.Box.nb_rgb_levels - 1),
1516             image_name,
1517             nrow=12,
1518             padding=1,
1519             pad_value=0.0,
1520         )
1521         logger(f"wrote {image_name}")
1522
1523
1524 ######################################################################