[picoclvr.git] / tasks.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, os, tqdm
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 from mygpt import BracketedSequence
16
17 # from graph import save_attention_image
18 save_attention_image = None
19
20 ######################################################################
21
22
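# Regenerate in place the entries of `input` whose positions are marked with
# a 1 in `ar_mask`, by running the model autoregressively batch by batch. The
# model is put in eval mode for the generation and its previous training/eval
# state is restored afterward. The per-token sampling itself is delegated to
# model.masked_inplace_autoregression (presumably defined on the model class
# in mygpt), with deterministic_synthesis presumably selecting greedy decoding
# rather than sampling.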
23 def masked_inplace_autoregression(
24     model,
25     batch_size,
26     input,
27     ar_mask,
28     deterministic_synthesis,
29     forbidden_tokens=None,
30     progress_bar_desc="autoregression",
31     device=torch.device("cpu"),
32 ):
33     assert input.size() == ar_mask.size()
34
35     batches = zip(input.split(batch_size), ar_mask.split(batch_size))
36
37     if progress_bar_desc is not None:
38         batches = tqdm.tqdm(
39             batches,
40             dynamic_ncols=True,
41             desc=progress_bar_desc,
42             total=(input.size(0) + batch_size - 1) // batch_size,
43         )
44
45     with torch.autograd.no_grad():
46         t = model.training
47         model.eval()
48
49         for input, ar_mask in batches:
50             model.masked_inplace_autoregression(
51                 input, ar_mask, forbidden_tokens, deterministic_synthesis
52             )
53
54         model.train(t)
55
56
57 ######################################################################
58
59
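# Minimal interface that every task below implements: batches() yields token
# tensors for a split, vocabulary_size() gives the number of distinct tokens,
# and produce_results() evaluates a trained model and logs or saves what it
# generates.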
60 class Task:
61     def batches(self, split="train"):
62         pass
63
64     def vocabulary_size(self):
65         pass
66
67     def produce_results(
68         self, n_epoch, model, result_dir, logger, deterministic_synthesis
69     ):
70         pass
71
72
73 class TaskFromFile(Task):
74     def tensorize(self, pairs, shuffle):
75         len_max = max([len(x[0]) for x in pairs])
76
77         input = torch.cat(
78             [
79                 torch.tensor(
80                     [
81                         [self.char2id[c] for c in s[0] + "#" * (len_max - len(s[0]))]
82                         for s in pairs
83                     ]
84                 )
85             ],
86             0,
87         ).to("cpu")
88
89         pred_mask = torch.cat(
90             [
91                 torch.tensor(
92                     [
93                         [int(c) for c in s[1] + "0" * (len_max - len(s[1]))]
94                         for s in pairs
95                     ]
96                 )
97             ],
98             0,
99         ).to("cpu")
100
101         if shuffle:
102             i = torch.randperm(input.size(0))
103             input = input[i].contiguous()
104             pred_mask = pred_mask[i].contiguous()
105
106         return input, pred_mask
107
108     # Trim the tensor z, removing as many padding-token columns as
109     # possible from the left and right. If z is a tuple, all its elements
110     # are trimmed according to the trimming computed for the first one.
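    # For instance, with "#" mapped to id 0, a sketch of the behaviour is
    #   self.trim(torch.tensor([[0, 0, 3, 5, 0],
    #                           [0, 2, 4, 0, 0]]))
    # which keeps only the column range still containing at least one
    # non-padding token, i.e. columns 1 to 3, and returns
    #   [[0, 3, 5],
    #    [2, 4, 0]]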
111     def trim(self, z, token="#"):
112         n = self.char2id[token]
113         if type(z) == tuple:
114             x = z[0]
115             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
116             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
117             return tuple([t[:, a:b] for t in z])
118         else:
119             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
120             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
121             return z[:, a:b]
122
123     def __init__(
124         self,
125         train_filename,
126         test_filename,
127         nb_train_samples,
128         nb_test_samples,
129         batch_size,
130         shuffle=False,
131         device=torch.device("cpu"),
132     ):
133         self.batch_size = batch_size
134         self.device = device
135
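        # The parsing below implies a file format of two lines per sample: the
        # token sequence itself, then a mask string of the same length made of
        # the characters "0", "1" and "2" (non-zero marks the positions to
        # predict, and "2" presumably marks the positions on which the test
        # accuracy is measured, see produce_results).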
136         def read_file(filename, nb=-1):
137             pairs = []
138             with open(filename, "r") as f:
139                 while True:
140                     sequence = f.readline().strip()
141                     if not sequence:
142                         break
143                     pred_mask = f.readline().strip()
144                     assert len(sequence) == len(pred_mask)
145                     assert set(pred_mask).issubset({"0", "1", "2"}), f"{set(pred_mask)}"
146                     pairs.append((sequence, pred_mask))
147                     if len(pairs) == nb:
148                         break
149
150             if nb > 0:
151                 pairs = pairs[:nb]
152                 assert len(pairs) == nb
153
154             return pairs
155
156         train_pairs = read_file(train_filename, nb_train_samples)
157         test_pairs = read_file(test_filename, nb_test_samples)
158
159         symbols = ["#"] + sorted(
160             set("".join([x[0] for x in train_pairs + test_pairs])) - set(["#"])
161         )
162         self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
163         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
164
165         self.train_input, self.train_pred_masks = self.tensorize(
166             train_pairs, shuffle=shuffle
167         )
168         self.test_input, self.test_pred_masks = self.tensorize(
169             test_pairs, shuffle=shuffle
170         )
171
172     def batches(self, split="train", nb_to_use=-1, desc=None):
173         assert split in {"train", "test"}
174         input = self.train_input if split == "train" else self.test_input
175         if nb_to_use > 0:
176             input = input[:nb_to_use]
177         if desc is None:
178             desc = f"epoch-{split}"
179         for batch in tqdm.tqdm(
180             input.split(self.batch_size), dynamic_ncols=True, desc=desc
181         ):
182             yield self.trim(batch).to(self.device)
183
184     def vocabulary_size(self):
185         return len(self.char2id)
186
187     def tensor2str(self, t):
188         return ["".join([self.id2char[x.item()] for x in s]) for s in t]
189
190     def produce_results(
191         self, n_epoch, model, result_dir, logger, deterministic_synthesis
192     ):
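        # pred_mask semantics: 0 = context kept as is, 1 and 2 = positions the
        # model has to regenerate, and only the positions marked 2 enter the
        # accuracy computed below through err_mask.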
193         correct = self.trim(self.test_input[:1000]).to(self.device)
194         result = correct.clone()
195         pred_mask = self.test_pred_masks[:1000, : result.size(1)].to(self.device)
196         ar_mask = (pred_mask > 0).long()
197         result *= 1 - ar_mask  # paraaaaanoiaaaaaaa
198
199         logger(f"----------------------------------------------------------")
200
201         for e in self.tensor2str(result[:50]):
202             logger(f"test_before {e}")
203
204         masked_inplace_autoregression(
205             model,
206             self.batch_size,
207             result,
208             ar_mask,
209             deterministic_synthesis,
210             device=self.device,
211         )
212
213         logger(f"----------------------------------------------------------")
214
215         for e, c in zip(self.tensor2str(result[:50]), self.tensor2str(correct[:50])):
216             logger(f"test_after  {e}")
217             logger(f"correct     {c}")
218
219         logger(f"----------------------------------------------------------")
220
221         err_mask = (pred_mask == 2).long()
222         nb_total = err_mask.sum().item()
223         nb_correct = ((correct == result).long() * err_mask).sum().item()
224
225         logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
226         logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
227
228
229 ####################
230
231 import problems
232
233
234 class SandBox(Task):
235     def __init__(
236         self,
237         problem,
238         nb_train_samples,
239         nb_test_samples,
240         batch_size,
241         logger=None,
242         device=torch.device("cpu"),
243         max_nb_codes=1024,
244     ):
245         super().__init__()
246
247         self.batch_size = batch_size
248         self.device = device
249         self.problem = problem
250
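        # problem.generate_sequences returns the token sequences together with
        # an autoregression mask of the same shape, in which a 1 marks the
        # positions the model will have to predict.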
251         self.train_input, self.train_ar_mask = self.problem.generate_sequences(
252             nb_train_samples
253         )
254         self.test_input, self.test_ar_mask = self.problem.generate_sequences(
255             nb_test_samples
256         )
257
258         self.train_input, self.train_ar_mask = self.train_input.to(
259             device
260         ), self.train_ar_mask.to(device)
261         self.test_input, self.test_ar_mask = self.test_input.to(
262             device
263         ), self.test_ar_mask.to(device)
264
265         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
266
267         # A bit of paranoia never hurts
268         assert self.nb_codes <= max_nb_codes
269         assert self.train_input.min() >= 0
270         assert self.test_input.min() >= 0
271         assert tuple(x.item() for x in self.train_ar_mask.unique()) in {
272             (0,),
273             (1,),
274             (0, 1),
275         }
276         assert tuple(x.item() for x in self.test_ar_mask.unique()) in {
277             (0,),
278             (1,),
279             (0, 1),
280         }
281
282         if logger is not None:
283             for s, a in zip(self.train_input[:100], self.train_ar_mask[:100]):
284                 logger(f"train_sequences {self.problem.seq2str(s)}")
285                 a = "".join(["01"[x.item()] for x in a])
286                 logger(f"                {a}")
287
288     def batches(self, split="train", nb_to_use=-1, desc=None):
289         assert split in {"train", "test"}
290         input = self.train_input if split == "train" else self.test_input
291         if nb_to_use > 0:
292             input = input[:nb_to_use]
293         if desc is None:
294             desc = f"epoch-{split}"
295         for batch in tqdm.tqdm(
296             input.split(self.batch_size), dynamic_ncols=True, desc=desc
297         ):
298             yield batch
299
300     def vocabulary_size(self):
301         return self.nb_codes
302
303     def produce_results(
304         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
305     ):
306         def compute_accuracy(input, ar_mask, logger=None):
307             input, ar_mask = input[:nmax], ar_mask[:nmax]
308             result = input.clone() * (1 - ar_mask)
309
310             masked_inplace_autoregression(
311                 model,
312                 self.batch_size,
313                 result,
314                 ar_mask,
315                 deterministic_synthesis,
316                 progress_bar_desc=None,
317                 device=self.device,
318             )
319
320             log_ground_truth = ar_mask.min() == 0
321
322             if logger is not None:
323                 for sp, st in zip(result[:10], input[:10]):
324                     logger(
325                         f"test_sequences {n_epoch} prediction   {self.problem.seq2str(sp)}"
326                     )
327                     if log_ground_truth:
328                         logger(
329                             f"               {n_epoch} ground truth {self.problem.seq2str(st)}"
330                         )
331
332             nb_total, nb_correct = self.problem.compute_nb_correct(
333                 input, ar_mask, result
334             )
335
336             # nb_total = ar_mask.sum().item()
337             # nb_correct = ((result == input).long() * ar_mask).sum().item()
338
339             return nb_total, nb_correct
340
341         train_nb_total, train_nb_correct = compute_accuracy(
342             self.train_input, self.train_ar_mask
343         )
344
345         logger(
346             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
347         )
348
349         test_nb_total, test_nb_correct = compute_accuracy(
350             self.test_input, self.test_ar_mask, logger
351         )
352
353         logger(
354             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
355         )
356
357         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
358
359         if save_attention_image is not None:
360             for k in range(10):
361                 ns = torch.randint(self.test_input.size(0), (1,)).item()
362                 input = self.test_input[ns : ns + 1].clone()
363
364                 with torch.autograd.no_grad():
365                     t = model.training
366                     model.eval()
367                     # model.record_attention(True)
368                     model(BracketedSequence(input))
369                     model.train(t)
370                     # ram = model.retrieve_attention()
371                     # model.record_attention(False)
372
373                 # tokens_output = [c for c in self.problem.seq2str(input[0])]
374                 # tokens_input = ["n/a"] + tokens_output[:-1]
375                 # for n_head in range(ram[0].size(1)):
376                 # filename = os.path.join(
377                 # result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
378                 # )
379                 # attention_matrices = [m[0, n_head] for m in ram]
380                 # save_attention_image(
381                 # filename,
382                 # tokens_input,
383                 # tokens_output,
384                 # attention_matrices,
385                 # k_top=10,
386                 ##min_total_attention=0.9,
387                 # token_gap=12,
388                 # layer_gap=50,
389                 # )
390                 # logger(f"wrote {filename}")
391
392
393 ######################################################################
394
395 import picoclvr
396
397
398 class PicoCLVR(Task):
399     # Make a tensor from a list of strings
400     def tensorize(self, descr):
401         token_descr = [s.strip().split(" ") for s in descr]
402         l = max([len(s) for s in token_descr])
403         token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
404         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
405         return torch.tensor(id_descr, device=self.device)
406
407     # Make a list of strings from a tensor
408     def detensorize(self, x):
409         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
410
411     # Trim the tensor z, removing as many padding-token columns as
412     # possible from the left and right. If z is a tuple, all its elements
413     # are trimmed according to the trimming computed for the first one.
414     def trim(self, z, token="<nul>"):
415         n = self.token2id[token]
416         if type(z) == tuple:
417             x = z[0]
418             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
419             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
420             return tuple([t[:, a:b] for t in z])
421         else:
422             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
423             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
424             return z[:, a:b]
425
426     ######################
427
428     def __init__(
429         self,
430         nb_train_samples,
431         nb_test_samples,
432         batch_size,
433         height,
434         width,
435         nb_colors=5,
436         logger=None,
437         device=torch.device("cpu"),
438         pruner_train=None,
439         pruner_eval=None,
440     ):
441         super().__init__()
442
443         def generate_descr(nb, cache_suffix, pruner):
444             return picoclvr.generate(
445                 nb,
446                 height=self.height,
447                 width=self.width,
448                 nb_colors=nb_colors,
449                 pruner=pruner,
450             )
451
452         self.height = height
453         self.width = width
454         self.batch_size = batch_size
455         self.device = device
456         self.pruner_train = pruner_train
457         self.pruner_eval = pruner_eval
458
459         if logger is not None:
460             logger(
461                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
462             )
463
464         self.train_descr = generate_descr(
465             nb_train_samples, "train", pruner=self.pruner_train
466         )
467         self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)
468
469         # Build the tokenizer
470         tokens = {"<nul>", "<img>"}
471         for d in [self.train_descr, self.test_descr]:
472             for s in d:
473                 for t in s.strip().split(" "):
474                     tokens.add(t)
475         # Convert the set into a sorted list so that the same descriptions
476         # always produce the same token ids, and hence the same tensors
477         tokens = list(tokens)
478         tokens.sort()
479         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
480         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
481         self.t_img, self.t_nul = self.token2id["<img>"], self.token2id["<nul>"]
482
483         # Tokenize the train and test sets
484         self.train_input = self.tensorize(self.train_descr)
485         self.test_input = self.tensorize(self.test_descr)
486
487     def batches(self, split="train"):
488         assert split in {"train", "test"}
489         input = self.train_input if split == "train" else self.test_input
490         for batch in tqdm.tqdm(
491             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
492         ):
493             yield self.trim(batch)
494
495     def vocabulary_size(self):
496         return len(self.token2id)
497
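    # Test-set evaluation: the image part of each sequence (everything from
    # the first <img> token onward) is erased and regenerated by the model,
    # then picoclvr.nb_properties counts how many of the requested properties
    # the generated image actually satisfies.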
498     def compute_missing_properties(
499         self, n_epoch, model, logger, deterministic_synthesis, pruner=None
500     ):
501         acc_nb_requested_properties = []
502         acc_nb_missing_properties = []
503         acc_nb_results = 0
504
505         for input in tqdm.tqdm(
506             self.test_input.split(self.batch_size),
507             dynamic_ncols=True,
508             desc=f"test-properties",
509         ):
510             result = input.clone()
511             ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1)
512             result = (1 - ar_mask) * result + ar_mask * self.t_nul
513             masked_inplace_autoregression(
514                 model,
515                 self.batch_size,
516                 result,
517                 ar_mask,
518                 deterministic_synthesis,
519                 progress_bar_desc=None,
520                 device=self.device,
521             )
522
523             result_descr = self.detensorize(result)
524             np = picoclvr.nb_properties(
525                 result_descr,
526                 height=self.height,
527                 width=self.width,
528                 pruner=pruner,
529             )
530             nb_requested_properties, _, nb_missing_properties = zip(*np)
531             acc_nb_requested_properties += nb_requested_properties
532             acc_nb_missing_properties += nb_missing_properties
533             acc_nb_results += len(result_descr)
534
535         nb_requested_properties = sum(acc_nb_requested_properties)
536         nb_missing_properties = sum(acc_nb_missing_properties)
537
538         prefix = "" if pruner is None else "pruned_"
539         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
540         logger(
541             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
542         )
543         logger(
544             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
545         )
546
547         logger(
548             f"main_test_accuracy {n_epoch} {1-nb_missing_properties/nb_requested_properties}"
549         )
550
551     ######################################################################
552
553     def produce_results(
554         self, n_epoch, model, result_dir, logger, deterministic_synthesis
555     ):
556         self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis)
557
558         if self.pruner_eval is not None:
559             self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis, pruner=self.pruner_eval)
560
561         nb_tokens_to_generate = self.height * self.width + 3
562         result_descr = []
563         nb_per_primer = 8
564         primer = []
565
566         for primer_descr in [
567             "red above green <sep> green top <sep> blue right of red",
568             "there is red <sep> there is yellow <sep> there is blue",
569             "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
570             "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
571         ]:
572             primer += [primer_descr + " <img>"] * nb_per_primer
573
574         result = self.tensorize(primer)
575         fill = result.new_full(
576             result.size()[:-1] + (self.height * self.width + 1,), self.t_nul
577         )
578         result = torch.cat((result, fill), 1)
579         ar_mask = (result == self.t_nul).long()
580         masked_inplace_autoregression(
581             model,
582             self.batch_size,
583             result,
584             ar_mask,
585             deterministic_synthesis,
586             device=self.device,
587         )
588         result_descr = self.detensorize(result)
589
590         np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
591
592         acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
593         acc_nb_results = len(result_descr)
594
595         nb_requested_properties = sum(acc_nb_requested_properties)
596         nb_missing_properties = sum(acc_nb_missing_properties)
597
598         prefix = "demo_"
599         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
600         logger(
601             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
602         )
603         logger(
604             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
605         )
606
607         img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)
608
609         if img.dim() == 5:
610             if img.size(1) == 1:
611                 img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
612             else:
613                 img = torch.cat(
614                     [
615                         torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
616                         for x in img
617                     ],
618                     0,
619                 )
620
621         image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png")
622         torchvision.utils.save_image(
623             img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
624         )
625         logger(f"wrote {image_name}")
626
627
628 ######################################################################
629
630
631 class MNIST(Task):
632     def __init__(
633         self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
634     ):
635         super().__init__()
636
637         self.nb_train_samples = nb_train_samples
638         self.nb_test_samples = nb_test_samples
639         self.batch_size = batch_size
640         self.device = device
641         data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True)
642         self.train_input = data_set.data[:nb_train_samples].view(-1, 28 * 28).long()
643         data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
644         self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
645
646     def batches(self, split="train", nb_to_use=-1, desc=None):
647         assert split in {"train", "test"}
648         input = self.train_input if split == "train" else self.test_input
649         if nb_to_use > 0:
650             input = input[:nb_to_use]
651         if desc is None:
652             desc = f"epoch-{split}"
653         for batch in tqdm.tqdm(
654             input.split(self.batch_size), dynamic_ncols=True, desc=desc
655         ):
656             yield batch
657
658     def vocabulary_size(self):
659         return 256
660
661     def produce_results(
662         self, n_epoch, model, result_dir, logger, deterministic_synthesis
663     ):
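        # Generate 64 images from scratch: every one of the 28 * 28 pixel
        # tokens is predicted (ar_mask is all ones), and the samples are saved
        # as a grid of 16 images per row.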
664         results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
665         ar_mask = torch.full_like(results, 1)
666         masked_inplace_autoregression(
667             model,
668             self.batch_size,
669             results,
670             ar_mask,
671             deterministic_synthesis,
672             device=self.device,
673         )
674         image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
675         torchvision.utils.save_image(
676             1 - results.reshape(-1, 1, 28, 28) / 255.0,
677             image_name,
678             nrow=16,
679             pad_value=0.8,
680         )
681         logger(f"wrote {image_name}")
682
683
684 ######################################################################
685
686 import maze
687
688
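# Each sample is the concatenation of a flattened height x width maze and a
# flattened height x width path, i.e. 2 * height * width tokens; the second
# half (the path) is what the model has to predict.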
689 class Maze(Task):
690     def map2seq(self, *m):
691         return torch.cat([x.flatten(1) for x in m], 1)
692
693     def seq2map(self, s):
694         s = s.reshape(s.size(0), -1, self.height, self.width)
695         return (s[:, k] for k in range(s.size(1)))
696
697     def __init__(
698         self,
699         nb_train_samples,
700         nb_test_samples,
701         batch_size,
702         height,
703         width,
704         nb_walls,
705         device=torch.device("cpu"),
706     ):
707         super().__init__()
708
709         self.batch_size = batch_size
710         self.height = height
711         self.width = width
712         self.device = device
713
714         train_mazes, train_paths, _ = maze.create_maze_data(
715             nb_train_samples,
716             height=height,
717             width=width,
718             nb_walls=nb_walls,
719             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
720         )
721         self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
722
723         test_mazes, test_paths, _ = maze.create_maze_data(
724             nb_test_samples,
725             height=height,
726             width=width,
727             nb_walls=nb_walls,
728             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
729         )
730         self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
731
732         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
733
734     def batches(self, split="train", nb_to_use=-1, desc=None):
735         assert split in {"train", "test"}
736         input = self.train_input if split == "train" else self.test_input
737         if nb_to_use > 0:
738             input = input[:nb_to_use]
739         if desc is None:
740             desc = f"epoch-{split}"
741         for batch in tqdm.tqdm(
742             input.split(self.batch_size), dynamic_ncols=True, desc=desc
743         ):
744             yield batch
745
746     def vocabulary_size(self):
747         return self.nb_codes
748
749     def compute_error(
750         self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
751     ):
752         nb_total, nb_correct = 0, 0
753         count = torch.zeros(
754             self.width * self.height,
755             self.width * self.height,
756             device=self.device,
757             dtype=torch.int64,
758         )
759
760         for input in self.batches(split, nb_to_use):
761             result = input.clone()
762             ar_mask = result.new_zeros(result.size())
763             ar_mask[:, self.height * self.width :] = 1
764             result *= 1 - ar_mask
765             masked_inplace_autoregression(
766                 model,
767                 self.batch_size,
768                 result,
769                 ar_mask,
770                 deterministic_synthesis,
771                 progress_bar_desc=None,
772                 device=self.device,
773             )
774             mazes, paths = self.seq2map(result)
775             path_correctness = maze.path_correctness(mazes, paths)
776             nb_correct += path_correctness.long().sum()
777             nb_total += mazes.size(0)
778
779             optimal_path_lengths = (
780                 (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
781             )
782             predicted_path_lengths = (
783                 (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
784             )
785             optimal_path_lengths = optimal_path_lengths[path_correctness]
786             predicted_path_lengths = predicted_path_lengths[path_correctness]
787             count.index_put_((optimal_path_lengths, predicted_path_lengths), torch.ones_like(optimal_path_lengths), accumulate=True)
788
789         if count.max() == 0:
790             count = None
791         else:
792             count = count[
793                 : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
794             ]
795
796         return nb_total, nb_correct, count
797
798     def produce_results(
799         self, n_epoch, model, result_dir, logger, deterministic_synthesis
800     ):
801         train_nb_total, train_nb_correct, count = self.compute_error(
802             model,
803             "train",
804             nb_to_use=1000,
805             deterministic_synthesis=deterministic_synthesis,
806         )
807         logger(
808             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
809         )
810
811         test_nb_total, test_nb_correct, count = self.compute_error(
812             model,
813             "test",
814             nb_to_use=1000,
815             deterministic_synthesis=deterministic_synthesis,
816         )
817         logger(
818             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
819         )
820
821         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
822
823         if count is not None:
824             proportion_optimal = count.diagonal().sum().float() / count.sum()
825             logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
826             with open(
827                 os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
828             ) as f:
829                 for i in range(count.size(0)):
830                     for j in range(count.size(1)):
831                         eol = " " if j < count.size(1) - 1 else "\n"
832                         f.write(f"{count[i,j]}{eol}")
833
834         input = self.test_input[:48]
835         result = input.clone()
836         ar_mask = result.new_zeros(result.size())
837         ar_mask[:, self.height * self.width :] = 1
838         result *= 1 - ar_mask
839         masked_inplace_autoregression(
840             model,
841             self.batch_size,
842             result,
843             ar_mask,
844             deterministic_synthesis,
845             device=self.device,
846         )
847
848         mazes, paths = self.seq2map(input)
849         _, predicted_paths = self.seq2map(result)
850
851         filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png")
852         maze.save_image(
853             filename,
854             mazes=mazes,
855             target_paths=paths,
856             predicted_paths=predicted_paths,
857             path_correct=maze.path_correctness(mazes, predicted_paths),
858             path_optimal=maze.path_optimality(paths, predicted_paths),
859         )
860         logger(f"wrote {filename}")
861
862
863 ######################################################################
864
865
866 import snake
867
868
869 class Snake(Task):
870     def __init__(
871         self,
872         nb_train_samples,
873         nb_test_samples,
874         batch_size,
875         height,
876         width,
877         nb_colors,
878         length,
879         prompt_length,
880         device=torch.device("cpu"),
881     ):
882         super().__init__()
883
884         self.batch_size = batch_size
885         self.height = height
886         self.width = width
887         self.device = device
888         self.prompt_length = prompt_length
889
890         self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
891             nb_train_samples,
892             height,
893             width,
894             nb_colors,
895             length,
896             prompt_length,
897             self.device,
898         )
899         self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
900             nb_test_samples,
901             height,
902             width,
903             nb_colors,
904             length,
905             prompt_length,
906             self.device,
907         )
908
909         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
910
911     def batches(self, split="train", nb_to_use=-1, desc=None):
912         assert split in {"train", "test"}
913         input = self.train_input if split == "train" else self.test_input
914         if nb_to_use > 0:
915             input = input[:nb_to_use]
916         if desc is None:
917             desc = f"epoch-{split}"
918         for batch in tqdm.tqdm(
919             input.split(self.batch_size), dynamic_ncols=True, desc=desc
920         ):
921             yield batch
922
923     def vocabulary_size(self):
924         return self.nb_codes
925
926     def produce_results(
927         self, n_epoch, model, result_dir, logger, deterministic_synthesis
928     ):
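        # The mask built below regenerates every second token from position
        # 2 * prompt_length onward, and accuracy is only measured on the
        # positions whose cell was already visited (prior_visits > 0), i.e.
        # where the value can presumably be recovered from the prompt.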
929         def compute_nb_correct(input, prior_visits):
930             result = input.clone()
931             i = torch.arange(result.size(1), device=result.device)[None, :]
932             ar_mask = (
933                 torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
934                 .long()
935                 .expand_as(result)
936             )
937             result *= 1 - ar_mask
938
939             masked_inplace_autoregression(
940                 model,
941                 self.batch_size,
942                 result,
943                 ar_mask,
944                 deterministic_synthesis,
945                 device=self.device,
946             )
947
948             nb_total = ((prior_visits > 0) * ar_mask).sum()
949
950             nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum()
951
952             return nb_total, nb_correct
953
954         test_nb_total, test_nb_correct = compute_nb_correct(
955             self.test_input[:1000], self.test_prior_visits[:1000]
956         )
957
958         logger(
959             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
960         )
961
962         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
963
964
965 ######################################################################
966
967
968 import stack
969
970
971 class Stack(Task):
972     def __init__(
973         self,
974         nb_train_samples,
975         nb_test_samples,
976         batch_size,
977         logger,
978         nb_steps,
979         nb_stacks,
980         nb_digits,
981         fraction_values_for_train=None,
982         device=torch.device("cpu"),
983     ):
984         super().__init__()
985
986         self.batch_size = batch_size
987         self.nb_steps = nb_steps
988         self.nb_stacks = nb_stacks
989         self.nb_digits = nb_digits
990         self.device = device
991
992         if fraction_values_for_train is None:
993             values_for_train = None
994             values_for_test = None
995         else:
996             all = torch.randperm(10**nb_digits)
997             nb_for_train = int(all.size(0) * fraction_values_for_train)
998             values_for_train = all[:nb_for_train]
999             values_for_test = all[nb_for_train:]
1000
1001         self.train_input, self.train_stack_counts = stack.generate_sequences(
1002             nb_train_samples,
1003             nb_steps,
1004             nb_stacks,
1005             nb_digits,
1006             values_for_train,
1007             self.device,
1008         )
1009
1010         self.test_input, self.test_stack_counts = stack.generate_sequences(
1011             nb_test_samples,
1012             nb_steps,
1013             nb_stacks,
1014             nb_digits,
1015             values_for_test,
1016             self.device,
1017         )
1018
1019         i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
1020         counts = self.test_stack_counts.flatten()[i.flatten()]
1021         counts = F.one_hot(counts).sum(0)
1022         logger(f"test_pop_stack_counts {counts}")
1023
1024         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1025
1026     def batches(self, split="train", nb_to_use=-1, desc=None):
1027         assert split in {"train", "test"}
1028         input = self.train_input if split == "train" else self.test_input
1029         if nb_to_use > 0:
1030             input = input[:nb_to_use]
1031         if desc is None:
1032             desc = f"epoch-{split}"
1033         for batch in tqdm.tqdm(
1034             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1035         ):
1036             yield batch
1037
1038     def vocabulary_size(self):
1039         return self.nb_codes
1040
1041     def produce_results(
1042         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1043     ):
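        # stack.remove_popped_values presumably blanks out the values returned
        # by the pop operations; the positions where result now differs from
        # input are the ones to regenerate, and a popped value, which spans
        # 1 + nb_digits tokens, is counted as correct only if all of its
        # tokens are.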
1044         def compute_nb_correct(input):
1045             result = input.clone()
1046             stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
1047             ar_mask = (result != input).long()
1048             masked_inplace_autoregression(
1049                 model,
1050                 self.batch_size,
1051                 result,
1052                 ar_mask,
1053                 deterministic_synthesis,
1054                 device=self.device,
1055             )
1056
1057             errors = ((result != input).long() * ar_mask).reshape(
1058                 -1, 1 + self.nb_digits
1059             )
1060             ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)
1061
1062             nb_total = ar_mask.max(1).values.sum()
1063             nb_correct = nb_total - errors.max(1).values.sum()
1064
1065             return nb_total, nb_correct
1066
1067         test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
1068
1069         logger(
1070             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1071         )
1072
1073         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
1074
1075         ##############################################################
1076         # Log a few generated sequences
1077         input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
1078         result = input.clone()
1079         stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
1080         ar_mask = (result != input).long()
1081
1082         # for n in range(result.size(0)):
1083         # logger(
1084         # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
1085         # )
1086
1087         masked_inplace_autoregression(
1088             model,
1089             self.batch_size,
1090             result,
1091             ar_mask,
1092             deterministic_synthesis,
1093             device=self.device,
1094         )
1095
1096         for n in range(result.size(0)):
1097             logger(
1098                 f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
1099             )
1100         ##############################################################
1101
1102
1103 ######################################################################
1104
1105 import rpl
1106
1107
1108 class RPL(Task):
1109     def tensorize(self, sequences):
1110         len_max = max([len(x) for x in sequences])
1111         return torch.cat(
1112             [
1113                 torch.tensor(
1114                     [
1115                         [
1116                             self.token2id[str(c)]
1117                             for c in s + ["<nul>"] * (len_max - len(s))
1118                         ]
1119                         for s in sequences
1120                     ]
1121                 )
1122             ],
1123             0,
1124         )
1125
1126     def seq2str(self, seq):
1127         return " ".join([self.id2token[i] for i in seq])
1128
1129     def __init__(
1130         self,
1131         nb_train_samples,
1132         nb_test_samples,
1133         batch_size,
1134         nb_starting_values=3,
1135         max_input=9,
1136         prog_len=6,
1137         nb_runs=5,
1138         no_prog=False,
1139         logger=None,
1140         device=torch.device("cpu"),
1141     ):
1142         super().__init__()
1143
1144         self.batch_size = batch_size
1145         self.device = device
1146         self.no_prog = no_prog
1147
1148         train_sequences = [
1149             rpl.generate(
1150                 nb_starting_values=nb_starting_values,
1151                 nb_result_values_max=4 * nb_starting_values,
1152                 max_input=max_input,
1153                 prog_len=prog_len,
1154                 nb_runs=nb_runs,
1155             )
1156             for _ in tqdm.tqdm(range(nb_train_samples), desc="train-data")
1157         ]
1158
1159         test_sequences = [
1160             rpl.generate(
1161                 nb_starting_values=nb_starting_values,
1162                 nb_result_values_max=4 * nb_starting_values,
1163                 max_input=max_input,
1164                 prog_len=prog_len,
1165                 nb_runs=nb_runs,
1166             )
1167             for _ in tqdm.tqdm(range(nb_test_samples), desc="test-data")
1168         ]
1169
1170         symbols = list(
1171             set(["<nul>"] + [x for l in train_sequences + test_sequences for x in l])
1172         )
1173         val_max = max([x if type(x) is int else 0 for x in symbols])
1174         symbols = list(filter(lambda x: type(x) is str, symbols))
1175         symbols.sort()
1176         symbols += [str(n) for n in range(val_max + 1)]
1177         self.token2id = dict([(c, n) for n, c in enumerate(symbols)])
1178         self.id2token = dict([(n, c) for c, n in self.token2id.items()])
1179
1180         self.t_nul = self.token2id["<nul>"]
1181         self.t_input = self.token2id["<in>"]
1182         self.t_output = self.token2id["<out>"]
1183         self.t_prog = self.token2id["<prg>"]
1184         self.t_end = self.token2id["<end>"]
1185
1186         self.train_input = self.tensorize(train_sequences)
1187         self.test_input = self.tensorize(test_sequences)
1188
1189         if no_prog:
1190             # Excise the program from every train and test example
1191             k = torch.arange(self.train_input.size(1), device=self.train_input.device)[
1192                 None, :
1193             ]
1194             p = (
1195                 ((self.train_input == self.t_prog).long() * k)
1196                 .max(1, keepdim=True)
1197                 .values
1198             )
1199             self.train_input = (
1200                 self.train_input * (k <= p).long()
1201                 + self.t_end * (k == p + 1).long()
1202                 + self.t_nul * (k > p + 1).long()
1203             )
1204             k = torch.arange(self.test_input.size(1), device=self.test_input.device)[
1205                 None, :
1206             ]
1207             p = (
1208                 ((self.test_input == self.t_prog).long() * k)
1209                 .max(1, keepdim=True)
1210                 .values
1211             )
1212             self.test_input = (
1213                 self.test_input * (k <= p).long()
1214                 + self.t_end * (k == p + 1).long()
1215                 + self.t_nul * (k > p + 1).long()
1216             )
1217
1218         if logger is not None:
1219             logger(f"value_max {val_max}")
1220             for x in self.train_input[:25]:
1221                 end = (x != self.t_nul).nonzero().max().item() + 1
1222                 seq = [self.id2token[i.item()] for i in x[:end]]
1223                 s = " ".join(seq)
1224                 logger(f"example_seq {s}")
1225
1226         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1227
1228     def batches(self, split="train", nb_to_use=-1, desc=None):
1229         assert split in {"train", "test"}
1230         input = self.train_input if split == "train" else self.test_input
1231         if nb_to_use > 0:
1232             input = input[:nb_to_use]
1233         if desc is None:
1234             desc = f"epoch-{split}"
1235         for batch in tqdm.tqdm(
1236             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1237         ):
1238             last = (batch != self.t_nul).max(0).values.nonzero().max() + 3
1239             batch = batch[:, :last].to(self.device)
1240             yield batch
1241
1242     def vocabulary_size(self):
1243         return self.nb_codes
1244
1245     def produce_results(
1246         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1247     ):
1248         # --------------------------------------------------------------------
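        # Erase everything strictly after the first <prg> token (the program)
        # and regenerate it from what precedes, presumably the input/output
        # examples; rpl.compute_nb_errors then presumably checks whether the
        # predicted program reproduces the example stacks.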
1249         def compute_nb_errors_prog(input, nb_to_log=0):
1250             result = input.clone()
1251             s = (result == self.t_prog).long()
1252             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1253             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1254
1255             masked_inplace_autoregression(
1256                 model,
1257                 self.batch_size,
1258                 result,
1259                 ar_mask,
1260                 deterministic_synthesis,
1261                 device=self.device,
1262             )
1263
1264             sum_nb_total, sum_nb_errors = 0, 0
1265             for one_input, one_result in zip(input, result):
1266                 seq = [self.id2token[i.item()] for i in one_result]
1267                 nb_total, nb_errors, prog, stacks = rpl.compute_nb_errors(seq)
1268                 sum_nb_total += 1
1269                 sum_nb_errors += 0 if nb_errors == 0 else 1
1270                 if nb_to_log > 0:
1271                     gt_seq = [self.id2token[i.item()] for i in one_input]
1272                     _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq)
1273                     gt_prog = " ".join([str(x) for x in gt_prog])
1274                     prog = " ".join([str(x) for x in prog])
1275                     comment = "*" if nb_errors == 0 else "-"
1276                     logger(f"{comment} PROG [{gt_prog}] PREDICTED [{prog}]")
1277                     for start_stack, target_stack, result_stack, correct in stacks:
1278                         comment = "*" if correct else "-"
1279                         start_stack = " ".join([str(x) for x in start_stack])
1280                         target_stack = " ".join([str(x) for x in target_stack])
1281                         result_stack = " ".join([str(x) for x in result_stack])
1282                         logger(
1283                             f"  {comment} [{start_stack}] -> [{target_stack}] PREDICTED [{result_stack}]"
1284                         )
1285                     nb_to_log -= 1
1286
1287             return sum_nb_total, sum_nb_errors
1288
1289         # --------------------------------------------------------------------
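        # Erase the tokens strictly between the last <out> token and the <prg>
        # token (the final output values) and regenerate them; a sample counts
        # as an error unless the whole sequence is reproduced exactly.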
1290         def compute_nb_errors_output(input, nb_to_log=0):
1291             result = input.clone()
1292             k = torch.arange(result.size(1), device=result.device)[None, :]
1293             last_output_idx = (
1294                 ((result == self.t_output) * k).max(dim=1, keepdim=True).values
1295             )
1296             first_prog_idx = (
1297                 ((result == self.t_prog) * k).max(dim=1, keepdim=True).values
1298             )
1299             ar_mask = (k > last_output_idx).long() * (k < first_prog_idx).long()
1300             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1301
1302             masked_inplace_autoregression(
1303                 model,
1304                 self.batch_size,
1305                 result,
1306                 ar_mask,
1307                 deterministic_synthesis,
1308                 device=self.device,
1309             )
1310
1311             sum_nb_total, sum_nb_errors = 0, 0
1312             for one_input, one_result, i, j in zip(
1313                 input, result, last_output_idx, first_prog_idx
1314             ):
1315                 seq = [self.id2token[i.item()] for i in one_result]
1316                 sum_nb_total += 1
1317                 correct = (one_input - one_result).abs().max() == 0
1318                 sum_nb_errors += 0 if correct else 1
1319                 if nb_to_log > 0:
1320                     result_stack = [
1321                         self.id2token[u.item()] for u in one_result[i : j + 1]
1322                     ]
1323                     target_stack = [
1324                         self.id2token[u.item()] for u in one_input[i : j + 1]
1325                     ]
1326                     comment = "*" if correct else "-"
1327                     result_stack = " ".join([str(x) for x in result_stack])
1328                     target_stack = " ".join([str(x) for x in target_stack])
1329                     logger(
1330                         f"output_test {comment} [{target_stack}] PREDICTED [{result_stack}]"
1331                     )
1332                     nb_to_log -= 1
1333
1334             return sum_nb_total, sum_nb_errors
1335
1336         # --------------------------------------------------------------------
1337
1338         if not self.no_prog:
1339             test_nb_total, test_nb_errors = compute_nb_errors_prog(
1340                 self.test_input[:1000].to(self.device), nb_to_log=10
1341             )
1342
1343             logger(
1344                 f"accuracy_prog_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1345             )
1346
1347             logger(f"main_test_accuracy {n_epoch} {1-test_nb_errors/test_nb_total}")
1348
1349         test_nb_total, test_nb_errors = compute_nb_errors_output(
1350             self.test_input[:1000].to(self.device), nb_to_log=10
1351         )
1352
1353         logger(
1354             f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1355         )
1356
1357         if save_attention_image is None:
1358             logger("no save_attention_image (is pycairo installed?)")
1359         else:
1360             ns = torch.randint(self.test_input.size(0), (1,)).item()
1361             input = self.test_input[ns : ns + 1].clone()
1362             last = (input != self.t_nul).max(0).values.nonzero().max() + 3
1363             input = input[:, :last].to(self.device)
1364
1365             with torch.autograd.no_grad():
1366                 t = model.training
1367                 model.eval()
1368                 model.record_attention(True)
1369                 model(BracketedSequence(input))
1370                 model.train(t)
1371                 ram = model.retrieve_attention()
1372                 model.record_attention(False)
1373
1374             tokens_output = [self.id2token[i.item()] for i in input[0]]
1375             tokens_input = ["n/a"] + tokens_output[:-1]
1376             for n_head in range(ram[0].size(1)):
1377                 filename = os.path.join(
1378                     result_dir, f"rpl_attention_{n_epoch}_h{n_head}.pdf"
1379                 )
1380                 attention_matrices = [m[0, n_head] for m in ram]
1381                 save_attention_image(
1382                     filename,
1383                     tokens_input,
1384                     tokens_output,
1385                     attention_matrices,
1386                     k_top=10,
1387                     # min_total_attention=0.9,
1388                     token_gap=12,
1389                     layer_gap=50,
1390                 )
1391                 logger(f"wrote {filename}")
1392
1393
1394 ######################################################################
1395
1396
1397 import expr
1398
1399
1400 class Expr(Task):
1401     def tensorize(self, sequences):
1402         len_max = max([len(x) for x in sequences])
1403         return torch.cat(
1404             [
1405                 torch.tensor(
1406                     [
1407                         [self.char2id[c] for c in s + "#" * (len_max - len(s))]
1408                         for s in sequences
1409                     ]
1410                 )
1411             ],
1412             0,
1413         ).to(self.device)
1414
1415     def __init__(
1416         self,
1417         nb_train_samples,
1418         nb_test_samples,
1419         nb_variables,
1420         sequence_length,
1421         operand_max,
1422         result_max,
1423         batch_size,
1424         device=torch.device("cpu"),
1425     ):
1426         super().__init__()
1427
1428         self.batch_size = batch_size
1429         self.device = device
1430
1431         train_sequences = expr.generate_sequences(
1432             nb_train_samples,
1433             nb_variables=nb_variables,
1434             length=sequence_length,
1435             operand_max=operand_max,
1436             result_max=result_max,
1437         )
1438
1439         test_sequences = expr.generate_sequences(
1440             nb_test_samples,
1441             nb_variables=nb_variables,
1442             length=sequence_length,
1443             operand_max=operand_max,
1444             result_max=result_max,
1445         )
1446
1447         symbols = list(set("#" + "".join(train_sequences + test_sequences)))
1448         symbols.sort()
1449
1450         self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
1451         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
1452
1453         self.filler, self.space = self.char2id["#"], self.char2id[" "]
1454
1455         self.train_input = self.tensorize(train_sequences)
1456         self.test_input = self.tensorize(test_sequences)
1457
1458         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1459
1460     def batches(self, split="train", nb_to_use=-1, desc=None):
1461         assert split in {"train", "test"}
1462         input = self.train_input if split == "train" else self.test_input
1463         if nb_to_use > 0:
1464             input = input[:nb_to_use]
1465         if desc is None:
1466             desc = f"epoch-{split}"
1467         for batch in tqdm.tqdm(
1468             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1469         ):
1470             last = (batch != self.filler).max(0).values.nonzero().max() + 3
1471             batch = batch[:, :last]
1472             yield batch
1473
1474     def vocabulary_size(self):
1475         return self.nb_codes
1476
1477     def seq2str(self, s):
1478         return "".join([self.id2char[k.item()] for k in s])
1479
1480     def produce_results(
1481         self,
1482         n_epoch,
1483         model,
1484         result_dir,
1485         logger,
1486         deterministic_synthesis,
1487         input_file=None,
1488     ):
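        # Each sequence appears to be an expression followed by a space and
        # the values to recover; everything strictly after the first space is
        # erased and regenerated by the model below.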
1489         def compute_nb_correct(input):
1490             result = input.clone()
1491             s = (result == self.space).long()
1492             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1493             result = (1 - ar_mask) * result + ar_mask * self.filler
1494             masked_inplace_autoregression(
1495                 model,
1496                 self.batch_size,
1497                 result,
1498                 ar_mask,
1499                 deterministic_synthesis,
1500                 device=self.device,
1501             )
1502
1503             nb_total = input.size(0)
1504             nb_correct = (input == result).long().min(1).values.sum()
1505
1506             #######################################################################
1507             # Compute predicted vs. true variable values
1508
1509             nb_delta = torch.zeros(5, dtype=torch.int64)
1510             nb_missed = 0
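            # nb_delta[d] counts predicted values whose absolute error is d;
            # nb_missed counts values that are absent, negative, or off by
            # nb_delta.size(0) or more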
1511
1512             values_input = expr.extract_results([self.seq2str(s) for s in input])
1513             values_result = expr.extract_results([self.seq2str(s) for s in result])
1514
1515             filename = os.path.join(result_dir, f"expr_result_{n_epoch:04d}.txt")
1516
1517             with open(filename, "w") as f:
1518                 for i, r in zip(values_input, values_result):
1519                     for n, vi in i.items():
1520                         vr = r.get(n)
1521                         f.write(f"{vi} {-1 if vr is None else vr}\n")
1522
1523                         if vr is None or vr < 0:
1524                             nb_missed += 1
1525                         else:
1526                             d = abs(vr - vi)
1527                             if d >= nb_delta.size(0):
1528                                 nb_missed += 1
1529                             else:
1530                                 nb_delta[d] += 1
1531
1532             ######################################################################
1533
1534             return nb_total, nb_correct, nb_delta, nb_missed
1535
1536         (
1537             test_nb_total,
1538             test_nb_correct,
1539             test_nb_delta,
1540             test_nb_missed,
1541         ) = compute_nb_correct(self.test_input[:10000])
1542
1543         logger(
1544             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1545         )
1546
1547         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
1548
1549         nb_total = test_nb_delta.sum() + test_nb_missed
1550         for d in range(test_nb_delta.size(0)):
1551             logger(
1552                 f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%"
1553             )
1554         logger(
1555             f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%"
1556         )
1557
1558         ##############################################################
1559         # Log a few generated sequences
1560         if input_file is None:
1561             input = self.test_input[:10]
1562         else:
1563             with open(input_file, "r") as f:
1564                 sequences = [e.strip() for e in f.readlines()]
1565                 sequences = [s + " " + "#" * 50 for s in sequences]
1566                 input = self.tensorize(sequences)
1567
1568         result = input.clone()
1569         s = (result == self.space).long()
1570         ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1571         result = (1 - ar_mask) * result + ar_mask * self.filler
1572
1573         for n in range(result.size(0)):
1574             logger(f"test_before {self.seq2str(result[n])}")
1575
1576         masked_inplace_autoregression(
1577             model,
1578             self.batch_size,
1579             result,
1580             ar_mask,
1581             deterministic_synthesis,
1582             device=self.device,
1583         )
1584
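        # Reference line: keep the generated part of the ground truth and blank
        # out the prompt with space tokens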
1585         correct = (1 - ar_mask) * self.space + ar_mask * input
1586         for n in range(result.size(0)):
1587             comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
1588             logger(f"test_after  {self.seq2str(result[n])} {comment}")
1589             logger(f"truth       {self.seq2str(correct[n])}")
1590         ##############################################################
1591
1592
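# Minimal usage sketch for the Expr task (not part of the original file); the
# parameter values and the training loop are hypothetical, only the methods
# defined above are assumed:
#
#   task = Expr(
#       nb_train_samples=1000, nb_test_samples=100, nb_variables=3,
#       sequence_length=40, operand_max=9, result_max=99, batch_size=25,
#   )
#   nb_tokens = task.vocabulary_size()
#   for input in task.batches(split="train"):
#       ...  # one optimization step of the model on this batch of sequences
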
1593 ######################################################################
1594
1595 import grid
1596
1597
1598 class Grid(Task):
1599     # Make a tensor from a list of strings
1600     def str2tensor(self, descr):
1601         token_descr = [s.strip().split(" ") for s in descr]
1602         len_max = max([len(s) for s in token_descr])
1603         token_descr = [s + ["#"] * (len_max - len(s)) for s in token_descr]
1604         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
1605         return torch.tensor(id_descr, device=self.device)
1606
1607     # Make a list of strings from a tensor
1608     def tensor2str(self, x):
1609         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
1610
1611     # Trim the tensor(s) to remove as many filler-token columns as possible
1612     # on the left and on the right. If z is a tuple, all its elements are
1613     # trimmed according to the trimming computed for the first one
1614     def trim(self, z, token="#"):
1615         n = self.token2id[token]
1616         if type(z) == tuple:
1617             x = z[0]
1618             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
1619             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
1620             return tuple([t[:, a:b] for t in z])
1621         else:
1622             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
1623             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
1624             return z[:, a:b]
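
    # Worked example of the trimming above, for a single row whose token ids
    # are [n, 5, 7, n] with n the filler id: after padding, the per-column
    # non-filler indicator is [0, 0, 1, 1, 0, 0] and its cumsum [0, 0, 1, 2, 2, 2],
    # hence a = 1, b = 3, and trim returns the [5, 7] columns only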
1625
1626     ######################
1627
1628     def __init__(
1629         self,
1630         nb_train_samples,
1631         nb_test_samples,
1632         batch_size,
1633         size,
1634         fraction_play=0.0,
1635         logger=None,
1636         device=torch.device("cpu"),
1637     ):
1638         super().__init__()
1639
1640         self.device = device
1641         self.batch_size = batch_size
1642         self.grid_factory = grid.GridFactory(size=size)
1643         self.fraction_play = fraction_play
1644
1645         if logger is not None:
1646             logger(
1647                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
1648             )
1649
1650         self.train_descr = self.grid_factory.generate_samples(
1651             nb=nb_train_samples,
1652             fraction_play=fraction_play,
1653             progress_bar=lambda r: tqdm.tqdm(r),
1654         )
1655
1656         self.test_descr = self.grid_factory.generate_samples(
1657             nb=nb_test_samples, fraction_play=0.0, progress_bar=lambda r: tqdm.tqdm(r)
1658         )
1659
1660         if fraction_play > 0:
1661             self.play_descr = self.grid_factory.generate_samples(
1662                 nb=25, fraction_play=1.0, progress_bar=lambda r: tqdm.tqdm(r)
1663             )
1664         else:
1665             self.play_descr = []
1666
1667         # Build the tokenizer
1668         tokens = set()
1669         for d in [self.train_descr, self.test_descr, self.play_descr]:
1670             for s in d:
1671                 for t in s.strip().split(" "):
1672                     tokens.add(t)
1673         # make this set a sorted list to get the same tensors given
1674         # the same descr
1675         tokens = list(tokens)
1676         tokens.sort()
1677         tokens = ["#"] + tokens
1678         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
1679         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
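        # Special tokens: "#" is the padding token, "true"/"false" are the
        # answer tokens regenerated in produce_results, and "|" marks where
        # generation starts in the "play" sequences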
1680         self.t_nul = self.token2id["#"]
1681         self.t_true = self.token2id["true"]
1682         self.t_false = self.token2id["false"]
1683         self.t_pipe = self.token2id["|"]
1684
1685         # Tokenize the train and test sets
1686         self.train_input = self.str2tensor(self.train_descr)
1687         self.test_input = self.str2tensor(self.test_descr)
1688         self.play_input = (
1689             None if len(self.play_descr) == 0 else self.str2tensor(self.play_descr)
1690         )
1691
1692     def batches(self, split="train"):
1693         assert split in {"train", "test"}
1694         input = self.train_input if split == "train" else self.test_input
1695         for batch in tqdm.tqdm(
1696             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
1697         ):
1698             yield self.trim(batch)
1699
1700     def vocabulary_size(self):
1701         return len(self.token2id)
1702
1703     def produce_results(
1704         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1705     ):
1706         correct = self.test_input[:1000]
1707         result = correct.clone()
1708         ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long()
1709         result *= 1 - ar_mask  # clear the answer tokens to be regenerated
1710
1711         logger(f"----------------------------------------------------------")
1712
1713         for e in self.tensor2str(result[:10]):
1714             logger(f"test_before {e}")
1715
1716         masked_inplace_autoregression(
1717             model,
1718             self.batch_size,
1719             result,
1720             ar_mask,
1721             deterministic_synthesis,
1722             device=self.device,
1723         )
1724
1725         logger(f"----------------------------------------------------------")
1726
1727         for e in self.tensor2str(result[:10]):
1728             logger(f"test_after  {e}")
1729
1730         logger(f"----------------------------------------------------------")
1731
1732         nb_total = ar_mask.sum().item()
1733         nb_correct = ((correct == result).long() * ar_mask).sum().item()
1734
1735         logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
1736         logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
1737
1738         if self.play_input is not None:
1739             result = self.play_input.clone()
1740             ar_mask = (result == self.t_pipe).long().cumsum(dim=1).clamp(max=1)
1741             result *= 1 - ar_mask  # clear everything from the first "|" token onward
1742
1743             logger(f"----------------------------------------------------------")
1744
1745             for e in self.tensor2str(result[:10]):
1746                 logger(f"play_before {e}")
1747
1748             masked_inplace_autoregression(
1749                 model,
1750                 self.batch_size,
1751                 result,
1752                 ar_mask,
1753                 deterministic_synthesis,
1754                 device=self.device,
1755             )
1756
1757             logger(f"----------------------------------------------------------")
1758
1759             for e in self.tensor2str(result[:10]):
1760                 logger(f"play_after  {e}")
1761
1762             logger(f"----------------------------------------------------------")
1763
1764
1765 ######################################################################
1766
1767 import qmlp
1768
1769
1770 class QMLP(Task):
1771     ######################
1772
1773     def __init__(
1774         self,
1775         nb_train_samples,
1776         nb_test_samples,
1777         batch_size,
1778         result_dir,
1779         logger=None,
1780         device=torch.device("cpu"),
1781     ):
1782         super().__init__()
1783
1784         self.device = device
1785         self.batch_size = batch_size
1786         self.nb_samples_per_mlp = 256
1787
1788         if logger is not None:
1789             logger(
1790                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
1791             )
1792
1793         seq, q_test_set, test_error = qmlp.generate_sequence_and_test_set(
1794             nb_mlps=nb_train_samples + nb_test_samples,
1795             nb_samples=self.nb_samples_per_mlp,
1796             device=self.device,
1797             batch_size=64,
1798             nb_epochs=250,
1799             nb_mlps_per_batch=1024,
1800         )
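        # seq holds one token sequence per generated MLP, q_test_set the
        # quantized test sets used for evaluation below, and test_error the
        # reference test errors of the MLPs trained by qmlp (saved to the
        # *_errors_ref.dat files)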
1801
1802         self.train_input = seq[:nb_train_samples]
1803         self.train_q_test_set = q_test_set[:nb_train_samples]
1804         self.train_ref_test_errors = test_error[:nb_train_samples]
1805         self.test_input = seq[nb_train_samples:]
1806         self.test_q_test_set = q_test_set[nb_train_samples:]
1807         self.test_ref_test_errors = test_error[nb_train_samples:]
1808
1809         filename = os.path.join(result_dir, f"train_errors_ref.dat")
1810         with open(filename, "w") as f:
1811             for e in self.train_ref_test_errors:
1812                 f.write(f"{e}\n")
1813
1814         filename = os.path.join(result_dir, f"test_errors_ref.dat")
1815         with open(filename, "w") as f:
1816             for e in self.test_ref_test_errors:
1817                 f.write(f"{e}\n")
1818
1819         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1820
1821     def batches(self, split="train"):
1822         assert split in {"train", "test"}
1823         input = self.train_input if split == "train" else self.test_input
1824         for batch in tqdm.tqdm(
1825             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
1826         ):
1827             yield batch
1828
1829     def vocabulary_size(self):
1830         return self.nb_codes
1831
1832     def produce_results(
1833         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1834     ):
1835         correct = self.test_input[:1000]
1836         result = correct.clone()
1837         ar_mask = (
1838             torch.arange(result.size(1), device=result.device)
1839             > self.nb_samples_per_mlp * 3 + 1
1840         ).long()[None, :]
1841         ar_mask = ar_mask.expand_as(result)
1842         result *= 1 - ar_mask  # clear the parameter part to be regenerated
1843
1844         masked_inplace_autoregression(
1845             model,
1846             self.batch_size,
1847             result,
1848             ar_mask,
1849             deterministic_synthesis,
1850             device=self.device,
1851         )
1852
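        # Layout assumed above: the first nb_samples_per_mlp * 3 tokens encode
        # the quantized training set, one separator token follows, and the
        # remaining tokens encode the quantized MLP parameters, which are the
        # part regenerated by the model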
1853         q_train_set = result[:, : self.nb_samples_per_mlp * 3]
1854         q_params = result[:, self.nb_samples_per_mlp * 3 + 1 :]
1855         error_test = qmlp.evaluate_q_params(q_params, self.test_q_test_set)
1856
1857         filename = os.path.join(result_dir, f"test_errors_{n_epoch:04d}.dat")
1858         with open(filename, "w") as f:
1859             for e in error_test:
1860                 f.write(f"{e}\n")
1861
1862
1863 ######################################################################
1864
1865 import escape
1866
1867
1868 class Escape(Task):
1869     def __init__(
1870         self,
1871         nb_train_samples,
1872         nb_test_samples,
1873         batch_size,
1874         height,
1875         width,
1876         T,
1877         nb_walls,
1878         logger=None,
1879         device=torch.device("cpu"),
1880     ):
1881         super().__init__()
1882
1883         self.batch_size = batch_size
1884         self.device = device
1885         self.height = height
1886         self.width = width
1887
1888         states, actions, rewards = escape.generate_episodes(
1889             nb_train_samples + nb_test_samples, height, width, T, nb_walls
1890         )
1891         seq = escape.episodes2seq(states, actions, rewards, lookahead_delta=T)
1892         # seq = seq[:, seq.size(1) // 3 : 2 * seq.size(1) // 3]
1893         self.train_input = seq[:nb_train_samples].to(self.device)
1894         self.test_input = seq[nb_train_samples:].to(self.device)
1895
1896         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1897
1898     def batches(self, split="train", nb_to_use=-1, desc=None):
1899         assert split in {"train", "test"}
1900         input = self.train_input if split == "train" else self.test_input
1901         if nb_to_use > 0:
1902             input = input[:nb_to_use]
1903         if desc is None:
1904             desc = f"epoch-{split}"
1905         for batch in tqdm.tqdm(
1906             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1907         ):
1908             yield batch
1909
1910     def vocabulary_size(self):
1911         return self.nb_codes
1912
1913     def thinking_autoregression(
1914         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
1915     ):
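        # Generate the episodes iteration by iteration: set the current
        # lookahead-reward token to 0 or -1 before sampling the next state,
        # set it to +1 before sampling the action and reward, then rewrite the
        # past lookahead-reward tokens so that they remain consistent with the
        # rewards actually generated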
1916         result = self.test_input[:250].clone()
1917         t = torch.arange(result.size(1), device=result.device)[None, :]
1918
1919         state_len = self.height * self.width
1920         it_len = state_len + 3  # state / action / reward / lookahead_reward
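        # With this layout, the tokens of iteration k span indices k * it_len
        # to (k + 1) * it_len - 1, the last one being a lookahead-reward token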
1921
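        # Regenerate in place the positions of result selected by ar_mask,
        # leaving the other tokens untouched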
1922         def ar(result, ar_mask):
1923             ar_mask = ar_mask.expand_as(result)
1924             result *= 1 - ar_mask
1925             masked_inplace_autoregression(
1926                 model,
1927                 self.batch_size,
1928                 result,
1929                 ar_mask,
1930                 deterministic_synthesis,
1931                 device=self.device,
1932                 progress_bar_desc=None,
1933             )
1934
1935         # Generate iteration after iteration
1936
1937         for u in tqdm.tqdm(
1938             range(it_len, result.size(1) - it_len + 1, it_len), desc="thinking"
1939         ):
1940             # Set the lookahead reward to either 0 or -1 for the current
1941             # iteration, then sample the next state
1942             s = -(torch.rand(result.size(0), device=result.device) < 0.2).long()
1943             result[:, u - 1] = s + 1 + escape.first_lookahead_rewards_code
1944             ar_mask = (t >= u).long() * (t < u + state_len).long()
1945             ar(result, ar_mask)
1946
1947             # Set the lookahead reward to +1 for the current iteration,
1948             # then sample the action and reward
1949             s = 1
1950             result[:, u - 1] = s + 1 + escape.first_lookahead_rewards_code
1951             ar_mask = (t >= u + state_len).long() * (t < u + state_len + 2).long()
1952             ar(result, ar_mask)
1953
1954             # Set the previous lookahead rewards to a consistent state
1955             for v in range(0, u, it_len):
1956                 # Extract the rewards
1957                 r = result[:, range(v + state_len + 1 + it_len, u + it_len - 1, it_len)]
1958                 r = r - escape.first_rewards_code - 1
1959                 a = r.min(dim=1).values
1960                 b = r.max(dim=1).values
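                # Keep the most negative future reward if there is one,
                # otherwise the most positive one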
1961                 s = (a < 0).long() * a + (a >= 0).long() * b
1962                 result[:, v + state_len + 2] = (
1963                     s + 1 + escape.first_lookahead_rewards_code
1964                 )
1965
1966         # Saving the generated sequences
1967
1968         s, a, r, lr = escape.seq2episodes(
1969             result, self.height, self.width, lookahead=True
1970         )
1971         seq_str = escape.episodes2str(
1972             s, a, r, lookahead_rewards=lr, unicode=True, ansi_colors=True
1973         )
1974
1975         filename = os.path.join(result_dir, f"test_thinking_seq_{n_epoch:04d}.txt")
1976         with open(filename, "w") as f:
1977             f.write(seq_str)
1978             logger(f"wrote {filename}")
1979
1980     def produce_results(
1981         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
1982     ):
1983         result = self.test_input[:250].clone()
1984
1985         # Saving the ground truth
1986
1987         s, a, r, lr = escape.seq2episodes(
1988             result, self.height, self.width, lookahead=True
1989         )
1990         seq_str = escape.episodes2str(
1991             s, a, r, lookahead_rewards=lr, unicode=True, ansi_colors=True
1992         )
1993
1994         filename = os.path.join(result_dir, f"test_true_seq_{n_epoch:04d}.txt")
1995         with open(filename, "w") as f:
1996             f.write(seq_str)
1997             logger(f"wrote {filename}")
1998
1999         # Re-generating from the first frame
2000
2001         ar_mask = (
2002             torch.arange(result.size(1), device=result.device)
2003             >= self.height * self.width + 3
2004         ).long()[None, :]
2005         ar_mask = ar_mask.expand_as(result)
2006         result *= 1 - ar_mask  # clear everything after the first iteration
2007
2008         masked_inplace_autoregression(
2009             model,
2010             self.batch_size,
2011             result,
2012             ar_mask,
2013             deterministic_synthesis,
2014             device=self.device,
2015         )
2016
2017         # Saving the generated sequences
2018
2019         s, a, r, lr = escape.seq2episodes(
2020             result, self.height, self.width, lookahead=True
2021         )
2022         seq_str = escape.episodes2str(
2023             s, a, r, lookahead_rewards=lr, unicode=True, ansi_colors=True
2024         )
2025
2026         filename = os.path.join(result_dir, f"test_seq_{n_epoch:04d}.txt")
2027         with open(filename, "w") as f:
2028             f.write(seq_str)
2029             logger(f"wrote {filename}")
2030
2031         self.thinking_autoregression(
2032             n_epoch, model, result_dir, logger, deterministic_synthesis, nmax
2033         )
2034
2035
2036 ######################################################################