1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, os, tqdm
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 from mygpt import BracketedSequence
16
17 try:
18     from graph import save_attention_image
19 except ImportError:
20     save_attention_image = None
21
22 ######################################################################
23
24
25 def masked_inplace_autoregression(
26     model,
27     batch_size,
28     input,
29     ar_mask,
30     deterministic_synthesis,
31     forbidden_tokens=None,
32     progress_bar_desc="autoregression",
33     device=torch.device("cpu"),
34 ):
35     assert input.size() == ar_mask.size()
36
37     batches = zip(input.split(batch_size), ar_mask.split(batch_size))
38
39     if progress_bar_desc is not None:
40         batches = tqdm.tqdm(
41             batches,
42             dynamic_ncols=True,
43             desc=progress_bar_desc,
44             total=(input.size(0) + batch_size - 1) // batch_size,
45         )
46
47     with torch.autograd.no_grad():
48         t = model.training
49         model.eval()
50
51         for input, ar_mask in batches:
52             model.masked_inplace_autoregression(
53                 input, ar_mask, forbidden_tokens, deterministic_synthesis
54             )
55
56         model.train(t)
57
58
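# Illustrative sketch (hypothetical helper, not used elsewhere in this file)
# of how the tasks below typically drive masked_inplace_autoregression: keep
# a prompt prefix, blank the rest, and let the model regenerate it in place.


def example_masked_completion(model, sequences, prompt_length, batch_size=25):
    # 1 on the positions to regenerate, 0 on the prompt positions
    ar_mask = torch.zeros_like(sequences)
    ar_mask[:, prompt_length:] = 1
    # blank the part to be generated
    result = sequences.clone() * (1 - ar_mask)
    masked_inplace_autoregression(
        model,
        batch_size,
        result,
        ar_mask,
        deterministic_synthesis=True,
        progress_bar_desc=None,
        device=sequences.device,
    )
    return result
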
59 ######################################################################
60
61
62 class Task:
63     def batches(self, split="train"):
64         pass
65
66     def vocabulary_size(self):
67         pass
68
69     def produce_results(
70         self, n_epoch, model, result_dir, logger, deterministic_synthesis
71     ):
72         pass
73
74
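# A hedged usage sketch of the Task interface above (hypothetical training
# loop, not part of this file): a task is instantiated once, batches() feeds
# the optimization steps, and produce_results() is called after each epoch,
# e.g.
#
#     task = MNIST(nb_train_samples=10000, nb_test_samples=1000, batch_size=25)
#     for input in task.batches(split="train"):
#         ...  # next-token prediction step on a (N, T) tensor of token ids
#     task.produce_results(n_epoch, model, result_dir, logger, True)
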
75 ####################
76
77 import problems
78
79
80 class SandBox(Task):
81     def __init__(
82         self,
83         problem,
84         nb_train_samples,
85         nb_test_samples,
86         batch_size,
87         logger=None,
88         device=torch.device("cpu"),
89         max_nb_codes=1024,
90     ):
91         super().__init__()
92
93         self.batch_size = batch_size
94         self.device = device
95         self.problem = problem
96
97         self.train_input, self.train_ar_mask = self.problem.generate_sequences(
98             nb_train_samples
99         )
100         self.test_input, self.test_ar_mask = self.problem.generate_sequences(
101             nb_test_samples
102         )
103
104         self.train_input, self.train_ar_mask = self.train_input.to(
105             device
106         ), self.train_ar_mask.to(device)
107         self.test_input, self.test_ar_mask = self.test_input.to(
108             device
109         ), self.test_ar_mask.to(device)
110
111         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
112
113         # A bit of paranoia never hurts
114         assert (
115             self.nb_codes <= max_nb_codes
116             and self.train_input.min() >= 0
117             and self.test_input.min() >= 0
118             and tuple(self.train_ar_mask.unique()) == (0, 1)
119             and tuple(self.test_ar_mask.unique()) == (0, 1)
120         )
121
122     def batches(self, split="train", nb_to_use=-1, desc=None):
123         assert split in {"train", "test"}
124         input = self.train_input if split == "train" else self.test_input
125         if nb_to_use > 0:
126             input = input[:nb_to_use]
127         if desc is None:
128             desc = f"epoch-{split}"
129         for batch in tqdm.tqdm(
130             input.split(self.batch_size), dynamic_ncols=True, desc=desc
131         ):
132             yield batch
133
134     def vocabulary_size(self):
135         return self.nb_codes
136
137     def produce_results(
138         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
139     ):
140         def compute_accuracy(input, ar_mask, logger=None):
141             input, ar_mask = input[:nmax], ar_mask[:nmax]
142             result = input.clone() * (1 - ar_mask)
143
144             masked_inplace_autoregression(
145                 model,
146                 self.batch_size,
147                 result,
148                 ar_mask,
149                 deterministic_synthesis,
150                 progress_bar_desc=None,
151                 device=self.device,
152             )
153
154             if logger is not None:
155                 for sp, st in zip(result[:10], input[:10]):
156                     logger(
157                         f"test_sequences {n_epoch} prediction   {self.problem.seq2str(sp)}"
158                     )
159                     logger(
160                         f"               {n_epoch} ground truth {self.problem.seq2str(st)}"
161                     )
162
163             nb_total = ar_mask.sum().item()
164             nb_correct = ((result == input).long() * ar_mask).sum().item()
165
166             return nb_total, nb_correct
167
168         train_nb_total, train_nb_correct = compute_accuracy(
169             self.train_input, self.train_ar_mask
170         )
171
172         logger(
173             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
174         )
175
176         test_nb_total, test_nb_correct = compute_accuracy(
177             self.test_input, self.test_ar_mask, logger
178         )
179
180         logger(
181             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
182         )
183
184         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
185
186         if save_attention_image is None:
187             logger("no save_attention_image (is pycairo installed?)")
188         else:
189             for k in range(10):
190                 ns = torch.randint(self.test_input.size(0), (1,)).item()
191                 input = self.test_input[ns : ns + 1].clone()
192
193                 with torch.autograd.no_grad():
194                     t = model.training
195                     model.eval()
196                     model.record_attention(True)
197                     model(BracketedSequence(input))
198                     model.train(t)
199                     ram = model.retrieve_attention()
200                     model.record_attention(False)
201
202                 tokens_output = [c for c in self.problem.seq2str(input[0])]
203                 tokens_input = ["n/a"] + tokens_output[:-1]
204                 for n_head in range(ram[0].size(1)):
205                     filename = os.path.join(
206                         result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
207                     )
208                     attention_matrices = [m[0, n_head] for m in ram]
209                     save_attention_image(
210                         filename,
211                         tokens_input,
212                         tokens_output,
213                         attention_matrices,
214                         k_top=10,
215                         # min_total_attention=0.9,
216                         token_gap=12,
217                         layer_gap=50,
218                     )
219                     logger(f"wrote {filename}")
220
221
222 ######################################################################
223
224 import picoclvr
225
226
227 class PicoCLVR(Task):
228     # Make a tensor from a list of strings
229     def tensorize(self, descr):
230         token_descr = [s.strip().split(" ") for s in descr]
231         l = max([len(s) for s in token_descr])
232         token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
233         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
234         return torch.tensor(id_descr, device=self.device)
235
236     # Make a list of strings from a tensor
237     def detensorize(self, x):
238         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
239
240     # Trim z by removing, at its left and right ends, the columns that
241     # contain only the given token. If z is a tuple, the trimming is
242     # computed on its first tensor and applied to all of its elements.
243     def trim(self, z, token="<nul>"):
244         n = self.token2id[token]
245         if type(z) == tuple:
246             x = z[0]
247             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
248             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
249             return tuple([t[:, a:b] for t in z])
250         else:
251             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
252             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
253             return z[:, a:b]
254
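    # A worked example of trim() with illustrative values, assuming the token
    # maps to id 0:
    #     z = [[0, 0, 5, 7, 0],
    #          [0, 3, 0, 2, 0]]
    # the first and last columns contain only 0, so trim(z) returns
    #     [[0, 5, 7],
    #      [3, 0, 2]]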
255     ######################
256
257     def __init__(
258         self,
259         nb_train_samples,
260         nb_test_samples,
261         batch_size,
262         height,
263         width,
264         nb_colors=5,
265         logger=None,
266         device=torch.device("cpu"),
267         pruner_train=None,
268         pruner_eval=None,
269     ):
270         super().__init__()
271
272         def generate_descr(nb, cache_suffix, pruner):
273             return picoclvr.generate(
274                 nb,
275                 height=self.height,
276                 width=self.width,
277                 nb_colors=nb_colors,
278                 pruner=pruner,
279             )
280
281         self.height = height
282         self.width = width
283         self.batch_size = batch_size
284         self.device = device
285         self.pruner_train = pruner_train
286         self.pruner_eval = pruner_eval
287
288         if logger is not None:
289             logger(
290                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
291             )
292
293         self.train_descr = generate_descr(
294             nb_train_samples, "train", pruner=self.pruner_train
295         )
296         self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)
297
298         # Build the tokenizer
299         tokens = {"<nul>", "<img>"}
300         for d in [self.train_descr, self.test_descr]:
301             for s in d:
302                 for t in s.strip().split(" "):
303                     tokens.add(t)
304         # Sort the token set into a list so that the same descriptions
305         # always yield the same token ids, hence the same tensors
306         tokens = list(tokens)
307         tokens.sort()
308         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
309         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
310         self.t_img, self.t_nul = self.token2id["<img>"], self.token2id["<nul>"]
311
312         # Tokenize the train and test sets
313         self.train_input = self.tensorize(self.train_descr)
314         self.test_input = self.tensorize(self.test_descr)
315
316     def batches(self, split="train"):
317         assert split in {"train", "test"}
318         input = self.train_input if split == "train" else self.test_input
319         for batch in tqdm.tqdm(
320             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
321         ):
322             yield self.trim(batch)
323
324     def vocabulary_size(self):
325         return len(self.token2id)
326
327     def compute_missing_properties(
328         self, n_epoch, model, logger, deterministic_synthesis, pruner=None
329     ):
330         acc_nb_requested_properties = []
331         acc_nb_missing_properties = []
332         acc_nb_results = 0
333
334         for input in tqdm.tqdm(
335             self.test_input.split(self.batch_size),
336             dynamic_ncols=True,
337             desc=f"test-properties",
338         ):
339             result = input.clone()
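            # ar_mask is 1 from the first <img> token to the end of the
            # sequence; that part is blanked with <nul> and regenerated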
340             ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1)
341             result = (1 - ar_mask) * result + ar_mask * self.t_nul
342             masked_inplace_autoregression(
343                 model,
344                 self.batch_size,
345                 result,
346                 ar_mask,
347                 deterministic_synthesis,
348                 progress_bar_desc=None,
349                 device=self.device,
350             )
351
352             result_descr = self.detensorize(result)
353             np = picoclvr.nb_properties(
354                 result_descr,
355                 height=self.height,
356                 width=self.width,
357                 pruner=pruner,
358             )
359             nb_requested_properties, _, nb_missing_properties = zip(*np)
360             acc_nb_requested_properties += nb_requested_properties
361             acc_nb_missing_properties += nb_missing_properties
362             acc_nb_results += len(result_descr)
363
364         nb_requested_properties = sum(acc_nb_requested_properties)
365         nb_missing_properties = sum(acc_nb_missing_properties)
366
367         prefix = "" if pruner is None else "pruned_"
368         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
369         logger(
370             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
371         )
372         logger(
373             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
374         )
375
376         logger(
377             f"main_test_accuracy {n_epoch} {1-nb_missing_properties/nb_requested_properties}"
378         )
379
380     ######################################################################
381
382     def produce_results(
383         self, n_epoch, model, result_dir, logger, deterministic_synthesis
384     ):
385         self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis)
386
387         if self.pruner_eval is not None:
388             self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis, self.pruner_eval)
389
390         nb_tokens_to_generate = self.height * self.width + 3
391         result_descr = []
392         nb_per_primer = 8
393         primer = []
394
395         for primer_descr in [
396             "red above green <sep> green top <sep> blue right of red",
397             "there is red <sep> there is yellow <sep> there is blue",
398             "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
399             "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
400         ]:
401             primer += [primer_descr + " <img>"] * nb_per_primer
402
403         result = self.tensorize(primer)
404         fill = result.new_full(
405             result.size()[:-1] + (self.height * self.width + 1,), self.t_nul
406         )
407         result = torch.cat((result, fill), 1)
408         ar_mask = (result == self.t_nul).long()
409         masked_inplace_autoregression(
410             model,
411             self.batch_size,
412             result,
413             ar_mask,
414             deterministic_synthesis,
415             device=self.device,
416         )
417         result_descr = self.detensorize(result)
418
419         np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
420
421         acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
422         acc_nb_results = len(result_descr)
423
424         nb_requested_properties = sum(acc_nb_requested_properties)
425         nb_missing_properties = sum(acc_nb_missing_properties)
426
427         prefix = "demo_"
428         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
429         logger(
430             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
431         )
432         logger(
433             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
434         )
435
436         img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)
437
438         if img.dim() == 5:
439             if img.size(1) == 1:
440                 img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
441             else:
442                 img = torch.cat(
443                     [
444                         torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
445                         for x in img
446                     ],
447                     0,
448                 )
449
450         image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png")
451         torchvision.utils.save_image(
452             img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
453         )
454         logger(f"wrote {image_name}")
455
456
457 ######################################################################
458
459
460 class MNIST(Task):
461     def __init__(
462         self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
463     ):
464         super().__init__()
465
466         self.nb_train_samples = nb_train_samples
467         self.nb_test_samples = nb_test_samples
468         self.batch_size = batch_size
469         self.device = device
470         data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True)
471         self.train_input = data_set.data[:nb_train_samples].view(-1, 28 * 28).long()
472         data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
473         self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
474
475     def batches(self, split="train", nb_to_use=-1, desc=None):
476         assert split in {"train", "test"}
477         input = self.train_input if split == "train" else self.test_input
478         if nb_to_use > 0:
479             input = input[:nb_to_use]
480         if desc is None:
481             desc = f"epoch-{split}"
482         for batch in tqdm.tqdm(
483             input.split(self.batch_size), dynamic_ncols=True, desc=desc
484         ):
485             yield batch
486
487     def vocabulary_size(self):
488         return 256
489
490     def produce_results(
491         self, n_epoch, model, result_dir, logger, deterministic_synthesis
492     ):
493         results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
494         ar_mask = torch.full_like(results, 1)
495         masked_inplace_autoregression(
496             model,
497             self.batch_size,
498             results,
499             ar_mask,
500             deterministic_synthesis,
501             device=self.device,
502         )
503         image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
504         torchvision.utils.save_image(
505             1 - results.reshape(-1, 1, 28, 28) / 255.0,
506             image_name,
507             nrow=16,
508             pad_value=0.8,
509         )
510         logger(f"wrote {image_name}")
511
512
513 ######################################################################
514
515 import maze
516
517
518 class Maze(Task):
519     def map2seq(self, *m):
520         return torch.cat([x.flatten(1) for x in m], 1)
521
522     def seq2map(self, s):
523         s = s.reshape(s.size(0), -1, self.height, self.width)
524         return (s[:, k] for k in range(s.size(1)))
525
526     def __init__(
527         self,
528         nb_train_samples,
529         nb_test_samples,
530         batch_size,
531         height,
532         width,
533         nb_walls,
534         device=torch.device("cpu"),
535     ):
536         super().__init__()
537
538         self.batch_size = batch_size
539         self.height = height
540         self.width = width
541         self.device = device
542
543         train_mazes, train_paths, _ = maze.create_maze_data(
544             nb_train_samples,
545             height=height,
546             width=width,
547             nb_walls=nb_walls,
548             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
549         )
550         self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
551
552         test_mazes, test_paths, _ = maze.create_maze_data(
553             nb_test_samples,
554             height=height,
555             width=width,
556             nb_walls=nb_walls,
557             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
558         )
559         self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
560
561         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
562
563     def batches(self, split="train", nb_to_use=-1, desc=None):
564         assert split in {"train", "test"}
565         input = self.train_input if split == "train" else self.test_input
566         if nb_to_use > 0:
567             input = input[:nb_to_use]
568         if desc is None:
569             desc = f"epoch-{split}"
570         for batch in tqdm.tqdm(
571             input.split(self.batch_size), dynamic_ncols=True, desc=desc
572         ):
573             yield batch
574
575     def vocabulary_size(self):
576         return self.nb_codes
577
578     def compute_error(
579         self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
580     ):
581         nb_total, nb_correct = 0, 0
582         count = torch.zeros(
583             self.width * self.height,
584             self.width * self.height,
585             device=self.device,
586             dtype=torch.int64,
587         )
588
589         for input in self.batches(split, nb_to_use):
590             result = input.clone()
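            # the first height*width tokens encode the maze (kept as the
            # prompt), the next height*width tokens encode the path, which
            # is blanked and regenerated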
591             ar_mask = result.new_zeros(result.size())
592             ar_mask[:, self.height * self.width :] = 1
593             result *= 1 - ar_mask
594             masked_inplace_autoregression(
595                 model,
596                 self.batch_size,
597                 result,
598                 ar_mask,
599                 deterministic_synthesis,
600                 progress_bar_desc=None,
601                 device=self.device,
602             )
603             mazes, paths = self.seq2map(result)
604             path_correctness = maze.path_correctness(mazes, paths)
605             nb_correct += path_correctness.long().sum()
606             nb_total += mazes.size(0)
607
608             optimal_path_lengths = (
609                 (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
610             )
611             predicted_path_lengths = (
612                 (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
613             )
614             optimal_path_lengths = optimal_path_lengths[path_correctness]
615             predicted_path_lengths = predicted_path_lengths[path_correctness]
616             count[optimal_path_lengths, predicted_path_lengths] += 1
617
618         if count.max() == 0:
619             count = None
620         else:
621             count = count[
622                 : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
623             ]
624
625         return nb_total, nb_correct, count
626
627     def produce_results(
628         self, n_epoch, model, result_dir, logger, deterministic_synthesis
629     ):
630         train_nb_total, train_nb_correct, count = self.compute_error(
631             model,
632             "train",
633             nb_to_use=1000,
634             deterministic_synthesis=deterministic_synthesis,
635         )
636         logger(
637             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
638         )
639
640         test_nb_total, test_nb_correct, count = self.compute_error(
641             model,
642             "test",
643             nb_to_use=1000,
644             deterministic_synthesis=deterministic_synthesis,
645         )
646         logger(
647             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
648         )
649
650         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
651
652         if count is not None:
653             proportion_optimal = count.diagonal().sum().float() / count.sum()
654             logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
655             with open(
656                 os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
657             ) as f:
658                 for i in range(count.size(0)):
659                     for j in range(count.size(1)):
660                         eol = " " if j < count.size(1) - 1 else "\n"
661                         f.write(f"{count[i,j]}{eol}")
662
663         input = self.test_input[:48]
664         result = input.clone()
665         ar_mask = result.new_zeros(result.size())
666         ar_mask[:, self.height * self.width :] = 1
667         result *= 1 - ar_mask
668         masked_inplace_autoregression(
669             model,
670             self.batch_size,
671             result,
672             ar_mask,
673             deterministic_synthesis,
674             device=self.device,
675         )
676
677         mazes, paths = self.seq2map(input)
678         _, predicted_paths = self.seq2map(result)
679
680         filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png")
681         maze.save_image(
682             filename,
683             mazes=mazes,
684             target_paths=paths,
685             predicted_paths=predicted_paths,
686             path_correct=maze.path_correctness(mazes, predicted_paths),
687             path_optimal=maze.path_optimality(paths, predicted_paths),
688         )
689         logger(f"wrote {filename}")
690
691
692 ######################################################################
693
694
695 import snake
696
697
698 class Snake(Task):
699     def __init__(
700         self,
701         nb_train_samples,
702         nb_test_samples,
703         batch_size,
704         height,
705         width,
706         nb_colors,
707         length,
708         prompt_length,
709         device=torch.device("cpu"),
710     ):
711         super().__init__()
712
713         self.batch_size = batch_size
714         self.height = height
715         self.width = width
716         self.device = device
717         self.prompt_length = prompt_length
718
719         self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
720             nb_train_samples,
721             height,
722             width,
723             nb_colors,
724             length,
725             prompt_length,
726             self.device,
727         )
728         self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
729             nb_test_samples,
730             height,
731             width,
732             nb_colors,
733             length,
734             prompt_length,
735             self.device,
736         )
737
738         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
739
740     def batches(self, split="train", nb_to_use=-1, desc=None):
741         assert split in {"train", "test"}
742         input = self.train_input if split == "train" else self.test_input
743         if nb_to_use > 0:
744             input = input[:nb_to_use]
745         if desc is None:
746             desc = f"epoch-{split}"
747         for batch in tqdm.tqdm(
748             input.split(self.batch_size), dynamic_ncols=True, desc=desc
749         ):
750             yield batch
751
752     def vocabulary_size(self):
753         return self.nb_codes
754
755     def produce_results(
756         self, n_epoch, model, result_dir, logger, deterministic_synthesis
757     ):
758         def compute_nb_correct(input, prior_visits):
759             result = input.clone()
760             i = torch.arange(result.size(1), device=result.device)[None, :]
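            # regenerate every second token past twice the prompt length
            # (presumably because the snake sequences interleave two token
            # streams, only one of which has to be predicted)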
761             ar_mask = (
762                 torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
763                 .long()
764                 .expand_as(result)
765             )
766             result *= 1 - ar_mask
767
768             masked_inplace_autoregression(
769                 model,
770                 self.batch_size,
771                 result,
772                 ar_mask,
773                 deterministic_synthesis,
774                 device=self.device,
775             )
776
777             nb_total = ((prior_visits > 0) * ar_mask).sum()
778
779             nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum()
780
781             return nb_total, nb_correct
782
783         test_nb_total, test_nb_correct = compute_nb_correct(
784             self.test_input[:1000], self.test_prior_visits[:1000]
785         )
786
787         logger(
788             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
789         )
790
791         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
792
793
794 ######################################################################
795
796
797 import stack
798
799
800 class Stack(Task):
801     def __init__(
802         self,
803         nb_train_samples,
804         nb_test_samples,
805         batch_size,
806         logger,
807         nb_steps,
808         nb_stacks,
809         nb_digits,
810         fraction_values_for_train=None,
811         device=torch.device("cpu"),
812     ):
813         super().__init__()
814
815         self.batch_size = batch_size
816         self.nb_steps = nb_steps
817         self.nb_stacks = nb_stacks
818         self.nb_digits = nb_digits
819         self.device = device
820
821         if fraction_values_for_train is None:
822             values_for_train = None
823             values_for_test = None
824         else:
825             all_values = torch.randperm(10**nb_digits)
826             nb_for_train = int(all_values.size(0) * fraction_values_for_train)
827             values_for_train = all_values[:nb_for_train]
828             values_for_test = all_values[nb_for_train:]
829
830         self.train_input, self.train_stack_counts = stack.generate_sequences(
831             nb_train_samples,
832             nb_steps,
833             nb_stacks,
834             nb_digits,
835             values_for_train,
836             self.device,
837         )
838
839         self.test_input, self.test_stack_counts = stack.generate_sequences(
840             nb_test_samples,
841             nb_steps,
842             nb_stacks,
843             nb_digits,
844             values_for_test,
845             self.device,
846         )
847
848         i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
849         counts = self.test_stack_counts.flatten()[i.flatten()]
850         counts = F.one_hot(counts).sum(0)
851         logger(f"test_pop_stack_counts {counts}")
852
853         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
854
855     def batches(self, split="train", nb_to_use=-1, desc=None):
856         assert split in {"train", "test"}
857         input = self.train_input if split == "train" else self.test_input
858         if nb_to_use > 0:
859             input = input[:nb_to_use]
860         if desc is None:
861             desc = f"epoch-{split}"
862         for batch in tqdm.tqdm(
863             input.split(self.batch_size), dynamic_ncols=True, desc=desc
864         ):
865             yield batch
866
867     def vocabulary_size(self):
868         return self.nb_codes
869
870     def produce_results(
871         self, n_epoch, model, result_dir, logger, deterministic_synthesis
872     ):
873         def compute_nb_correct(input):
874             result = input.clone()
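            # remove_popped_values() presumably blanks the digits returned by
            # the pop operations, so ar_mask selects exactly the positions the
            # model has to predict (sequences are groups of one marker token
            # followed by nb_digits digit tokens, see the reshape below)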
875             stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
876             ar_mask = (result != input).long()
877             masked_inplace_autoregression(
878                 model,
879                 self.batch_size,
880                 result,
881                 ar_mask,
882                 deterministic_synthesis,
883                 device=self.device,
884             )
885
886             errors = ((result != input).long() * ar_mask).reshape(
887                 -1, 1 + self.nb_digits
888             )
889             ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)
890
891             nb_total = ar_mask.max(1).values.sum()
892             nb_correct = nb_total - errors.max(1).values.sum()
893
894             return nb_total, nb_correct
895
896         test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
897
898         logger(
899             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
900         )
901
902         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
903
904         ##############################################################
905         # Log a few generated sequences
906         input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
907         result = input.clone()
908         stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
909         ar_mask = (result != input).long()
910
911         # for n in range(result.size(0)):
912         # logger(
913         # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
914         # )
915
916         masked_inplace_autoregression(
917             model,
918             self.batch_size,
919             result,
920             ar_mask,
921             deterministic_synthesis,
922             device=self.device,
923         )
924
925         for n in range(result.size(0)):
926             logger(
927                 f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
928             )
929         ##############################################################
930
931
932 ######################################################################
933
934 import rpl
935
936
937 class RPL(Task):
938     def tensorize(self, sequences):
939         len_max = max([len(x) for x in sequences])
940         return torch.cat(
941             [
942                 torch.tensor(
943                     [
944                         [
945                             self.token2id[str(c)]
946                             for c in s + ["<nul>"] * (len_max - len(s))
947                         ]
948                         for s in sequences
949                     ]
950                 )
951             ],
952             0,
953         )
954
955     def seq2str(self, seq):
956         return " ".join([self.id2token[i] for i in seq])
957
958     def __init__(
959         self,
960         nb_train_samples,
961         nb_test_samples,
962         batch_size,
963         nb_starting_values=3,
964         max_input=9,
965         prog_len=6,
966         nb_runs=5,
967         no_prog=False,
968         logger=None,
969         device=torch.device("cpu"),
970     ):
971         super().__init__()
972
973         self.batch_size = batch_size
974         self.device = device
975         self.no_prog = no_prog
976
977         train_sequences = [
978             rpl.generate(
979                 nb_starting_values=nb_starting_values,
980                 nb_result_values_max=4 * nb_starting_values,
981                 max_input=max_input,
982                 prog_len=prog_len,
983                 nb_runs=nb_runs,
984             )
985             for _ in tqdm.tqdm(range(nb_train_samples), desc="train-data")
986         ]
987
988         test_sequences = [
989             rpl.generate(
990                 nb_starting_values=nb_starting_values,
991                 nb_result_values_max=4 * nb_starting_values,
992                 max_input=max_input,
993                 prog_len=prog_len,
994                 nb_runs=nb_runs,
995             )
996             for _ in tqdm.tqdm(range(nb_test_samples), desc="test-data")
997         ]
998
999         symbols = list(
1000             set(["<nul>"] + [x for l in train_sequences + test_sequences for x in l])
1001         )
1002         val_max = max([x if type(x) is int else 0 for x in symbols])
1003         symbols = list(filter(lambda x: type(x) is str, symbols))
1004         symbols.sort()
1005         symbols += [str(n) for n in range(val_max + 1)]
1006         self.token2id = dict([(c, n) for n, c in enumerate(symbols)])
1007         self.id2token = dict([(n, c) for c, n in self.token2id.items()])
1008
1009         self.t_nul = self.token2id["<nul>"]
1010         self.t_input = self.token2id["<in>"]
1011         self.t_output = self.token2id["<out>"]
1012         self.t_prog = self.token2id["<prg>"]
1013         self.t_end = self.token2id["<end>"]
1014
1015         self.train_input = self.tensorize(train_sequences)
1016         self.test_input = self.tensorize(test_sequences)
1017
1018         if no_prog:
1019             # Excise the program from every train and test example
1020             k = torch.arange(self.train_input.size(1), device=self.train_input.device)[
1021                 None, :
1022             ]
1023             p = (
1024                 ((self.train_input == self.t_prog).long() * k)
1025                 .max(1, keepdim=True)
1026                 .values
1027             )
1028             self.train_input = (
1029                 self.train_input * (k <= p).long()
1030                 + self.t_end * (k == p + 1).long()
1031                 + self.t_nul * (k > p + 1).long()
1032             )
1033             k = torch.arange(self.test_input.size(1), device=self.test_input.device)[
1034                 None, :
1035             ]
1036             p = (
1037                 ((self.test_input == self.t_prog).long() * k)
1038                 .max(1, keepdim=True)
1039                 .values
1040             )
1041             self.test_input = (
1042                 self.test_input * (k <= p).long()
1043                 + self.t_end * (k == p + 1).long()
1044                 + self.t_nul * (k > p + 1).long()
1045             )
1046
1047         if logger is not None:
1048             logger(f"value_max {val_max}")
1049             for x in self.train_input[:25]:
1050                 end = (x != self.t_nul).nonzero().max().item() + 1
1051                 seq = [self.id2token[i.item()] for i in x[:end]]
1052                 s = " ".join(seq)
1053                 logger(f"example_seq {s}")
1054
1055         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1056
1057     def batches(self, split="train", nb_to_use=-1, desc=None):
1058         assert split in {"train", "test"}
1059         input = self.train_input if split == "train" else self.test_input
1060         if nb_to_use > 0:
1061             input = input[:nb_to_use]
1062         if desc is None:
1063             desc = f"epoch-{split}"
1064         for batch in tqdm.tqdm(
1065             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1066         ):
1067             last = (batch != self.t_nul).max(0).values.nonzero().max() + 3
1068             batch = batch[:, :last].to(self.device)
1069             yield batch
1070
1071     def vocabulary_size(self):
1072         return self.nb_codes
1073
1074     def produce_results(
1075         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1076     ):
1077         # --------------------------------------------------------------------
1078         def compute_nb_errors_prog(input, nb_to_log=0):
1079             result = input.clone()
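            # the mask covers everything strictly after the <prg> token, i.e.
            # the program that the model has to synthesize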
1080             s = (result == self.t_prog).long()
1081             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1082             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1083
1084             masked_inplace_autoregression(
1085                 model,
1086                 self.batch_size,
1087                 result,
1088                 ar_mask,
1089                 deterministic_synthesis,
1090                 device=self.device,
1091             )
1092
1093             sum_nb_total, sum_nb_errors = 0, 0
1094             for one_input, one_result in zip(input, result):
1095                 seq = [self.id2token[i.item()] for i in one_result]
1096                 nb_total, nb_errors, prog, stacks = rpl.compute_nb_errors(seq)
1097                 sum_nb_total += 1
1098                 sum_nb_errors += 0 if nb_errors == 0 else 1
1099                 if nb_to_log > 0:
1100                     gt_seq = [self.id2token[i.item()] for i in one_input]
1101                     _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq)
1102                     gt_prog = " ".join([str(x) for x in gt_prog])
1103                     prog = " ".join([str(x) for x in prog])
1104                     comment = "*" if nb_errors == 0 else "-"
1105                     logger(f"{comment} PROG [{gt_prog}] PREDICTED [{prog}]")
1106                     for start_stack, target_stack, result_stack, correct in stacks:
1107                         comment = "*" if correct else "-"
1108                         start_stack = " ".join([str(x) for x in start_stack])
1109                         target_stack = " ".join([str(x) for x in target_stack])
1110                         result_stack = " ".join([str(x) for x in result_stack])
1111                         logger(
1112                             f"  {comment} [{start_stack}] -> [{target_stack}] PREDICTED [{result_stack}]"
1113                         )
1114                     nb_to_log -= 1
1115
1116             return sum_nb_total, sum_nb_errors
1117
1118         # --------------------------------------------------------------------
1119         def compute_nb_errors_output(input, nb_to_log=0):
1120             result = input.clone()
1121             k = torch.arange(result.size(1), device=result.device)[None, :]
1122             last_output_idx = (
1123                 ((result == self.t_output) * k).max(dim=1, keepdim=True).values
1124             )
1125             first_prog_idx = (
1126                 ((result == self.t_prog) * k).max(dim=1, keepdim=True).values
1127             )
1128             ar_mask = (k > last_output_idx).long() * (k < first_prog_idx).long()
1129             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1130
1131             masked_inplace_autoregression(
1132                 model,
1133                 self.batch_size,
1134                 result,
1135                 ar_mask,
1136                 deterministic_synthesis,
1137                 device=self.device,
1138             )
1139
1140             sum_nb_total, sum_nb_errors = 0, 0
1141             for one_input, one_result, i, j in zip(
1142                 input, result, last_output_idx, first_prog_idx
1143             ):
1144                 seq = [self.id2token[i.item()] for i in one_result]
1145                 sum_nb_total += 1
1146                 correct = (one_input - one_result).abs().max() == 0
1147                 sum_nb_errors += 0 if correct else 1
1148                 if nb_to_log > 0:
1149                     result_stack = [
1150                         self.id2token[i.item()] for i in one_result[i : j + 1]
1151                     ]
1152                     target_stack = [
1153                         self.id2token[i.item()] for i in one_input[i : j + 1]
1154                     ]
1155                     comment = "*" if correct else "-"
1156                     result_stack = " ".join([str(x) for x in result_stack])
1157                     target_stack = " ".join([str(x) for x in target_stack])
1158                     logger(
1159                         f"output_test {comment} [{target_stack}] PREDICTED [{result_stack}]"
1160                     )
1161                     nb_to_log -= 1
1162
1163             return sum_nb_total, sum_nb_errors
1164
1165         # --------------------------------------------------------------------
1166
1167         if not self.no_prog:
1168             test_nb_total, test_nb_errors = compute_nb_errors_prog(
1169                 self.test_input[:1000].to(self.device), nb_to_log=10
1170             )
1171
1172             logger(
1173                 f"accuracy_prog_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1174             )
1175
1176             logger(f"main_test_accuracy {n_epoch} {1-test_nb_errors/test_nb_total}")
1177
1178         test_nb_total, test_nb_errors = compute_nb_errors_output(
1179             self.test_input[:1000].to(self.device), nb_to_log=10
1180         )
1181
1182         logger(
1183             f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1184         )
1185
1186         if save_attention_image is None:
1187             logger("no save_attention_image (is pycairo installed?)")
1188         else:
1189             ns = torch.randint(self.test_input.size(0), (1,)).item()
1190             input = self.test_input[ns : ns + 1].clone()
1191             last = (input != self.t_nul).max(0).values.nonzero().max() + 3
1192             input = input[:, :last].to(self.device)
1193
1194             with torch.autograd.no_grad():
1195                 t = model.training
1196                 model.eval()
1197                 model.record_attention(True)
1198                 model(BracketedSequence(input))
1199                 model.train(t)
1200                 ram = model.retrieve_attention()
1201                 model.record_attention(False)
1202
1203             tokens_output = [self.id2token[i.item()] for i in input[0]]
1204             tokens_input = ["n/a"] + tokens_output[:-1]
1205             for n_head in range(ram[0].size(1)):
1206                 filename = os.path.join(
1207                     result_dir, f"rpl_attention_{n_epoch}_h{n_head}.pdf"
1208                 )
1209                 attention_matrices = [m[0, n_head] for m in ram]
1210                 save_attention_image(
1211                     filename,
1212                     tokens_input,
1213                     tokens_output,
1214                     attention_matrices,
1215                     k_top=10,
1216                     # min_total_attention=0.9,
1217                     token_gap=12,
1218                     layer_gap=50,
1219                 )
1220                 logger(f"wrote {filename}")
1221
1222
1223 ######################################################################
1224
1225
1226 import expr
1227
1228
1229 class Expr(Task):
1230     def tensorize(self, sequences):
1231         len_max = max([len(x) for x in sequences])
1232         return torch.cat(
1233             [
1234                 torch.tensor(
1235                     [
1236                         [self.char2id[c] for c in s + "#" * (len_max - len(s))]
1237                         for s in sequences
1238                     ]
1239                 )
1240             ],
1241             0,
1242         ).to(self.device)
1243
1244     def __init__(
1245         self,
1246         nb_train_samples,
1247         nb_test_samples,
1248         nb_variables,
1249         sequence_length,
1250         operand_max,
1251         result_max,
1252         batch_size,
1253         device=torch.device("cpu"),
1254     ):
1255         super().__init__()
1256
1257         self.batch_size = batch_size
1258         self.device = device
1259
1260         train_sequences = expr.generate_sequences(
1261             nb_train_samples,
1262             nb_variables=nb_variables,
1263             length=sequence_length,
1264             operand_max=operand_max,
1265             result_max=result_max,
1266         )
1267
1268         test_sequences = expr.generate_sequences(
1269             nb_test_samples,
1270             nb_variables=nb_variables,
1271             length=sequence_length,
1272             operand_max=operand_max,
1273             result_max=result_max,
1274         )
1275
1276         symbols = list(set("#" + "".join(train_sequences + test_sequences)))
1277         symbols.sort()
1278
1279         self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
1280         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
1281
1282         self.filler, self.space = self.char2id["#"], self.char2id[" "]
1283
1284         self.train_input = self.tensorize(train_sequences)
1285         self.test_input = self.tensorize(test_sequences)
1286
1287         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1288
1289     def batches(self, split="train", nb_to_use=-1, desc=None):
1290         assert split in {"train", "test"}
1291         input = self.train_input if split == "train" else self.test_input
1292         if nb_to_use > 0:
1293             input = input[:nb_to_use]
1294         if desc is None:
1295             desc = f"epoch-{split}"
1296         for batch in tqdm.tqdm(
1297             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1298         ):
1299             last = (batch != self.filler).max(0).values.nonzero().max() + 3
1300             batch = batch[:, :last]
1301             yield batch
1302
1303     def vocabulary_size(self):
1304         return self.nb_codes
1305
1306     def seq2str(self, s):
1307         return "".join([self.id2char[k.item()] for k in s])
1308
1309     def produce_results(
1310         self,
1311         n_epoch,
1312         model,
1313         result_dir,
1314         logger,
1315         deterministic_synthesis,
1316         input_file=None,
1317     ):
1318         def compute_nb_correct(input):
1319             result = input.clone()
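            # blank with the filler and regenerate everything after the first
            # space, i.e. the part from which expr.extract_results() later
            # reads the predicted values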
1320             s = (result == self.space).long()
1321             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1322             result = (1 - ar_mask) * result + ar_mask * self.filler
1323             masked_inplace_autoregression(
1324                 model,
1325                 self.batch_size,
1326                 result,
1327                 ar_mask,
1328                 deterministic_synthesis,
1329                 device=self.device,
1330             )
1331
1332             nb_total = input.size(0)
1333             nb_correct = (input == result).long().min(1).values.sum()
1334
1335             #######################################################################
1336             # Compute predicted vs. true variable values
1337
1338             nb_delta = torch.zeros(5, dtype=torch.int64)
1339             nb_missed = 0
1340
1341             values_input = expr.extract_results([self.seq2str(s) for s in input])
1342             values_result = expr.extract_results([self.seq2str(s) for s in result])
1343
1344             filename = os.path.join(result_dir, f"expr_result_{n_epoch:04d}.txt")
1345
1346             with open(filename, "w") as f:
1347                 for i, r in zip(values_input, values_result):
1348                     for n, vi in i.items():
1349                         vr = r.get(n)
1350                         f.write(f"{vi} {-1 if vr is None else vr}\n")
1351
1352                         if vr is None or vr < 0:
1353                             nb_missed += 1
1354                         else:
1355                             d = abs(vr - vi)
1356                             if d >= nb_delta.size(0):
1357                                 nb_missed += 1
1358                             else:
1359                                 nb_delta[d] += 1
1360
1361             ######################################################################
1362
1363             return nb_total, nb_correct, nb_delta, nb_missed
1364
1365         (
1366             test_nb_total,
1367             test_nb_correct,
1368             test_nb_delta,
1369             test_nb_missed,
1370         ) = compute_nb_correct(self.test_input[:10000])
1371
1372         logger(
1373             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1374         )
1375
1376         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
1377
1378         nb_total = test_nb_delta.sum() + test_nb_missed
1379         for d in range(test_nb_delta.size(0)):
1380             logger(
1381                 f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%"
1382             )
1383         logger(
1384             f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%"
1385         )
1386
1387         ##############################################################
1388         # Log a few generated sequences
1389         if input_file is None:
1390             input = self.test_input[:10]
1391         else:
1392             with open(input_file, "r") as f:
1393                 sequences = [e.strip() for e in f.readlines()]
1394                 sequences = [s + " " + "#" * 50 for s in sequences]
1395                 input = self.tensorize(sequences)
1396
1397         result = input.clone()
1398         s = (result == self.space).long()
1399         ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1400         result = (1 - ar_mask) * result + ar_mask * self.filler
1401
1402         for n in range(result.size(0)):
1403             logger(f"test_before {self.seq2str(result[n])}")
1404
1405         masked_inplace_autoregression(
1406             model,
1407             self.batch_size,
1408             result,
1409             ar_mask,
1410             deterministic_synthesis,
1411             device=self.device,
1412         )
1413
1414         correct = (1 - ar_mask) * self.space + ar_mask * input
1415         for n in range(result.size(0)):
1416             comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
1417             logger(f"test_after  {self.seq2str(result[n])} {comment}")
1418             logger(f"truth       {self.seq2str(correct[n])}")
1419         ##############################################################
1420
1421
1422 ######################################################################
1423
1424 import grid
1425
1426
1427 class Grid(Task):
1428     # Make a tensor from a list of strings
1429     def tensorize(self, descr):
1430         token_descr = [s.strip().split(" ") for s in descr]
1431         l = max([len(s) for s in token_descr])
1432         token_descr = [s + ["#"] * (l - len(s)) for s in token_descr]
1433         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
1434         return torch.tensor(id_descr, device=self.device)
1435
1436     # Make a list of strings from a tensor
1437     def detensorize(self, x):
1438         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
1439
1440     # Trim z by removing, at its left and right ends, the columns that
1441     # contain only the given token. If z is a tuple, the trimming is
1442     # computed on its first tensor and applied to all of its elements.
1443     def trim(self, z, token="#"):
1444         n = self.token2id[token]
1445         if type(z) == tuple:
1446             x = z[0]
1447             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
1448             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
1449             return tuple([t[:, a:b] for t in z])
1450         else:
1451             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
1452             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
1453             return z[:, a:b]
1454
1455     ######################
1456
1457     def __init__(
1458         self,
1459         nb_train_samples,
1460         nb_test_samples,
1461         batch_size,
1462         size,
1463         logger=None,
1464         device=torch.device("cpu"),
1465     ):
1466         super().__init__()
1467
1468         self.device = device
1469         self.batch_size = batch_size
1470         self.grid_factory = grid.GridFactory(size=size)
1471
1472         if logger is not None:
1473             logger(
1474                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
1475             )
1476
1477         self.train_descr = self.grid_factory.generate_samples(
1478             nb_train_samples, lambda r: tqdm.tqdm(r)
1479         )
1480         self.test_descr = self.grid_factory.generate_samples(
1481             nb_test_samples, lambda r: tqdm.tqdm(r)
1482         )
1483
1484         # Build the tokenizer
1485         tokens = set()
1486         for d in [self.train_descr, self.test_descr]:
1487             for s in d:
1488                 for t in s.strip().split(" "):
1489                     tokens.add(t)
1490         # Sort the token set into a list so that the same descriptions
1491         # always yield the same token ids, hence the same tensors
1492         tokens = list(tokens)
1493         tokens.sort()
1494         tokens = ["#"] + tokens
1495         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
1496         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
1497         self.t_nul = self.token2id["#"]
1498         self.t_true = self.token2id["true"]
1499         self.t_false = self.token2id["false"]
1500
1501         # Tokenize the train and test sets
1502         self.train_input = self.tensorize(self.train_descr)
1503         self.test_input = self.tensorize(self.test_descr)
1504
1505     def batches(self, split="train"):
1506         assert split in {"train", "test"}
1507         input = self.train_input if split == "train" else self.test_input
1508         for batch in tqdm.tqdm(
1509             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
1510         ):
1511             yield self.trim(batch)
1512
1513     def vocabulary_size(self):
1514         return len(self.token2id)
1515
1516     def produce_results(
1517         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1518     ):
1519         correct = self.test_input[:1000]
1520         result = correct.clone()
1521         ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long()
1522         result *= 1 - ar_mask
1523
1524         for e in self.detensorize(result[:10]):
1525             logger(f"test_before {e}")
1526
1527         masked_inplace_autoregression(
1528             model,
1529             self.batch_size,
1530             result,
1531             ar_mask,
1532             deterministic_synthesis,
1533             device=self.device,
1534         )
1535
1536         for e in self.detensorize(result[:10]):
1537             logger(f"test_after {e}")
1538
1539         nb_total = ar_mask.sum().item()
1540         nb_correct = ((correct == result).long() * ar_mask).sum().item()
1541
1542         logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
1543         logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
1544
1545
1546 ######################################################################
1547
1548 import world
1549
1550
1551 class World(Task):
1552     def __init__(
1553         self,
1554         nb_train_samples,
1555         nb_test_samples,
1556         batch_size,
1557         vqae_nb_epochs,
1558         logger=None,
1559         device=torch.device("cpu"),
1560         device_storage=torch.device("cpu"),
1561     ):
1562         super().__init__()
1563
1564         self.batch_size = batch_size
1565         self.device = device
1566
1567         (
1568             train_frames,
1569             train_action_seq,
1570             test_frames,
1571             test_action_seq,
1572             self.frame2seq,
1573             self.seq2frame,
1574         ) = world.create_data_and_processors(
1575             nb_train_samples,
1576             nb_test_samples,
1577             mode="first_last",
1578             nb_steps=30,
1579             nb_epochs=vqae_nb_epochs,
1580             logger=logger,
1581             device=device,
1582             device_storage=device_storage,
1583         )
1584
1585         train_frame_seq = self.frame2seq(train_frames).to(device_storage)
1586         test_frame_seq = self.frame2seq(test_frames).to(device_storage)
1587
1588         nb_frame_codes = max(train_frame_seq.max(), test_frame_seq.max()) + 1
1589         nb_action_codes = max(train_action_seq.max(), test_action_seq.max()) + 1
1590
1591         self.len_frame_seq = train_frame_seq.size(1)
1592         self.len_action_seq = train_action_seq.size(1)
1593         self.nb_codes = nb_frame_codes + nb_action_codes
1594
1595         train_frame_seq = train_frame_seq.reshape(train_frame_seq.size(0) // 2, 2, -1)
1596
1597         train_action_seq += nb_frame_codes
1598         self.train_input = torch.cat(
1599             (train_frame_seq[:, 0, :], train_action_seq, train_frame_seq[:, 1, :]), 1
1600         )
1601
1602         test_frame_seq = test_frame_seq.reshape(test_frame_seq.size(0) // 2, 2, -1)
1603         test_action_seq += nb_frame_codes
1604         self.test_input = torch.cat(
1605             (test_frame_seq[:, 0, :], test_action_seq, test_frame_seq[:, 1, :]), 1
1606         )
1607
1608     def batches(self, split="train", nb_to_use=-1, desc=None):
1609         assert split in {"train", "test"}
1610         input = self.train_input if split == "train" else self.test_input
1611         if nb_to_use > 0:
1612             input = input[:nb_to_use]
1613         if desc is None:
1614             desc = f"epoch-{split}"
1615         for batch in tqdm.tqdm(
1616             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1617         ):
1618             yield batch.to(self.device)
1619
1620     def vocabulary_size(self):
1621         return self.nb_codes
1622
1623     def produce_results(
1624         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1625     ):
1626         k = torch.arange(
1627             2 * self.len_frame_seq + self.len_action_seq, device=self.device
1628         )[None, :]
1629
1630         input = self.test_input[:64].to(self.device)
1631         result = input.clone()
1632
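        # regenerate only the last frame: the first frame and the action
        # tokens are kept as the prompt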
1633         ar_mask = (
1634             (k >= self.len_frame_seq + self.len_action_seq).long().expand_as(result)
1635         )
1636         result *= 1 - ar_mask
1637
1638         masked_inplace_autoregression(
1639             model,
1640             self.batch_size,
1641             result,
1642             ar_mask,
1643             deterministic_synthesis,
1644             device=self.device,
1645         )
1646
1647         seq_start = input[:, : self.len_frame_seq]
1648         seq_end = input[:, self.len_frame_seq + self.len_action_seq :]
1649         seq_predicted = result[:, self.len_frame_seq + self.len_action_seq :]
1650
1651         result = torch.cat(
1652             (seq_start[:, None, :], seq_end[:, None, :], seq_predicted[:, None, :]), 1
1653         )
1654         result = result.reshape(-1, result.size(-1))
1655
1656         frames = self.seq2frame(result)
1657         image_name = os.path.join(result_dir, f"world_result_{n_epoch:04d}.png")
1658         torchvision.utils.save_image(
1659             frames.float() / (world.Box.nb_rgb_levels - 1),
1660             image_name,
1661             nrow=12,
1662             padding=1,
1663             pad_value=0.0,
1664         )
1665         logger(f"wrote {image_name}")
1666
1667
1668 ######################################################################