1ea3b5d588d37e4c89ca4d06f74d20356b7a2fcf
[picoclvr.git] / tasks.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, os, tqdm
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 from mygpt import BracketedSequence
16
17 # from graph import save_attention_image
18 save_attention_image = None
19
20 ######################################################################
21
22
23 def masked_inplace_autoregression(
24     model,
25     batch_size,
26     input,
27     ar_mask,
28     deterministic_synthesis,
29     forbidden_tokens=None,
30     progress_bar_desc="autoregression",
31     device=torch.device("cpu"),
32 ):
33     assert input.size() == ar_mask.size()
34
35     batches = zip(input.split(batch_size), ar_mask.split(batch_size))
36
37     if progress_bar_desc is not None:
38         batches = tqdm.tqdm(
39             batches,
40             dynamic_ncols=True,
41             desc=progress_bar_desc,
42             total=(input.size(0) + batch_size - 1) // batch_size,
43         )
44
45     with torch.autograd.no_grad():
46         t = model.training
47         model.eval()
48
49         for input, ar_mask in batches:
50             model.masked_inplace_autoregression(
51                 input, ar_mask, forbidden_tokens, deterministic_synthesis
52             )
53
54         model.train(t)
55
56
57 ######################################################################
58
59
60 class Task:
61     def batches(self, split="train"):
62         pass
63
64     def vocabulary_size(self):
65         pass
66
67     def produce_results(
68         self, n_epoch, model, result_dir, logger, deterministic_synthesis
69     ):
70         pass
71
72
73 class TaskFromFile(Task):
74     def tensorize(self, pairs):
75         len_max = max([len(x[0]) for x in pairs])
76
77         input = torch.cat(
78             [
79                 torch.tensor(
80                     [
81                         [self.char2id[c] for c in s[0] + "#" * (len_max - len(s[0]))]
82                         for s in pairs
83                     ]
84                 )
85             ],
86             0,
87         ).to("cpu")
88
89         pred_mask = torch.cat(
90             [
91                 torch.tensor(
92                     [
93                         [int(c) for c in s[1] + "0" * (len_max - len(s[1]))]
94                         for s in pairs
95                     ]
96                 )
97             ],
98             0,
99         ).to("cpu")
100
101         return input, pred_mask
102
103     # trim all the tensors in the tuple z to remove as much token from
104     # left and right in the first tensor. If z is a tuple, all its
105     # elements are trimed according to the triming for the first
106     def trim(self, z, token="#"):
107         n = self.char2id[token]
108         if type(z) == tuple:
109             x = z[0]
110             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
111             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
112             return tuple([t[:, a:b] for t in z])
113         else:
114             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
115             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
116             return z[:, a:b]
117
118     def __init__(
119         self,
120         filename,
121         nb_train_samples,
122         nb_test_samples,
123         batch_size,
124         device=torch.device("cpu"),
125     ):
126         self.batch_size = batch_size
127         self.device = device
128
129         pairs = []
130         with open(filename, "r") as f:
131             for _ in range(nb_train_samples + nb_test_samples):
132                 sequence = f.readline().strip()
133                 pred_mask = f.readline().strip()
134                 assert len(sequence) == len(pred_mask)
135                 assert set(pred_mask).issubset({"0", "1", "2"}), f"{set(pred_mask)}"
136                 pairs.append((sequence, pred_mask))
137
138         symbols = ["#"] + list(set("".join([x[0] for x in pairs])) - set(["#"]))
139         self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
140         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
141
142         self.train_input, self.train_pred_masks = self.tensorize(
143             pairs[:nb_train_samples]
144         )
145         self.test_input, self.test_pred_masks = self.tensorize(pairs[nb_train_samples:])
146
147         assert self.train_input.size(0) == nb_train_samples
148         assert self.test_input.size(0) == nb_test_samples
149
150     def batches(self, split="train", nb_to_use=-1, desc=None):
151         assert split in {"train", "test"}
152         input = self.train_input if split == "train" else self.test_input
153         if nb_to_use > 0:
154             input = input[:nb_to_use]
155         if desc is None:
156             desc = f"epoch-{split}"
157         for batch in tqdm.tqdm(
158             input.split(self.batch_size), dynamic_ncols=True, desc=desc
159         ):
160             yield self.trim(batch).to(self.device)
161
162     def vocabulary_size(self):
163         return len(self.char2id)
164
165     def tensor2str(self, t):
166         return ["".join([self.id2char[x.item()] for x in s]) for s in t]
167
168     def produce_results(
169         self, n_epoch, model, result_dir, logger, deterministic_synthesis
170     ):
171         correct = self.trim(self.test_input[:1000]).to(self.device)
172         result = correct.clone()
173         pred_mask = self.test_pred_masks[:1000, : result.size(1)].to(self.device)
174         ar_mask = (pred_mask > 0).long()
175         result *= 1 - ar_mask  # paraaaaanoiaaaaaaa
176
177         logger(f"----------------------------------------------------------")
178
179         for e in self.tensor2str(result[:10]):
180             logger(f"test_before {e}")
181
182         masked_inplace_autoregression(
183             model,
184             self.batch_size,
185             result,
186             ar_mask,
187             deterministic_synthesis,
188             device=self.device,
189         )
190
191         logger(f"----------------------------------------------------------")
192
193         for e, c in zip(self.tensor2str(result[:10]), self.tensor2str(correct[:10])):
194             logger(f"test_after  {e}")
195             logger(f"correct     {c}")
196
197         logger(f"----------------------------------------------------------")
198
199         err_mask = (pred_mask == 2).long()
200         nb_total = err_mask.sum().item()
201         nb_correct = ((correct == result).long() * err_mask).sum().item()
202
203         logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
204         logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
205
206
207 ####################
208
209 import problems
210
211
212 class SandBox(Task):
213     def __init__(
214         self,
215         problem,
216         nb_train_samples,
217         nb_test_samples,
218         batch_size,
219         logger=None,
220         device=torch.device("cpu"),
221         max_nb_codes=1024,
222     ):
223         super().__init__()
224
225         self.batch_size = batch_size
226         self.device = device
227         self.problem = problem
228
229         self.train_input, self.train_ar_mask = self.problem.generate_sequences(
230             nb_train_samples
231         )
232         self.test_input, self.test_ar_mask = self.problem.generate_sequences(
233             nb_test_samples
234         )
235
236         self.train_input, self.train_ar_mask = self.train_input.to(
237             device
238         ), self.train_ar_mask.to(device)
239         self.test_input, self.test_ar_mask = self.test_input.to(
240             device
241         ), self.test_ar_mask.to(device)
242
243         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
244
245         # A bit of paranoia never hurts
246         assert self.nb_codes <= max_nb_codes
247         assert self.train_input.min() >= 0
248         assert self.test_input.min() >= 0
249         assert tuple(x.item() for x in self.train_ar_mask.unique()) in {
250             (0,),
251             (1,),
252             (0, 1),
253         }
254         assert tuple(x.item() for x in self.test_ar_mask.unique()) in {
255             (0,),
256             (1,),
257             (0, 1),
258         }
259
260         if logger is not None:
261             for s, a in zip(self.train_input[:100], self.train_ar_mask[:100]):
262                 logger(f"train_sequences {self.problem.seq2str(s)}")
263                 a = "".join(["01"[x.item()] for x in a])
264                 logger(f"                {a}")
265
266     def batches(self, split="train", nb_to_use=-1, desc=None):
267         assert split in {"train", "test"}
268         input = self.train_input if split == "train" else self.test_input
269         if nb_to_use > 0:
270             input = input[:nb_to_use]
271         if desc is None:
272             desc = f"epoch-{split}"
273         for batch in tqdm.tqdm(
274             input.split(self.batch_size), dynamic_ncols=True, desc=desc
275         ):
276             yield batch
277
278     def vocabulary_size(self):
279         return self.nb_codes
280
281     def produce_results(
282         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
283     ):
284         def compute_accuracy(input, ar_mask, logger=None):
285             input, ar_mask = input[:nmax], ar_mask[:nmax]
286             result = input.clone() * (1 - ar_mask)
287
288             masked_inplace_autoregression(
289                 model,
290                 self.batch_size,
291                 result,
292                 ar_mask,
293                 deterministic_synthesis,
294                 progress_bar_desc=None,
295                 device=self.device,
296             )
297
298             log_ground_truth = ar_mask.min() == 0
299
300             if logger is not None:
301                 for sp, st in zip(result[:10], input[:10]):
302                     logger(
303                         f"test_sequences {n_epoch} prediction   {self.problem.seq2str(sp)}"
304                     )
305                     if log_ground_truth:
306                         logger(
307                             f"               {n_epoch} ground truth {self.problem.seq2str(st)}"
308                         )
309
310             nb_total, nb_correct = self.problem.compute_nb_correct(
311                 input, ar_mask, result
312             )
313
314             # nb_total = ar_mask.sum().item()
315             # nb_correct = ((result == input).long() * ar_mask).sum().item()
316
317             return nb_total, nb_correct
318
319         train_nb_total, train_nb_correct = compute_accuracy(
320             self.train_input, self.train_ar_mask
321         )
322
323         logger(
324             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
325         )
326
327         test_nb_total, test_nb_correct = compute_accuracy(
328             self.test_input, self.test_ar_mask, logger
329         )
330
331         logger(
332             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
333         )
334
335         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
336
337         if save_attention_image is not None:
338             for k in range(10):
339                 ns = torch.randint(self.test_input.size(0), (1,)).item()
340                 input = self.test_input[ns : ns + 1].clone()
341
342                 with torch.autograd.no_grad():
343                     t = model.training
344                     model.eval()
345                     # model.record_attention(True)
346                     model(BracketedSequence(input))
347                     model.train(t)
348                     # ram = model.retrieve_attention()
349                     # model.record_attention(False)
350
351                 # tokens_output = [c for c in self.problem.seq2str(input[0])]
352                 # tokens_input = ["n/a"] + tokens_output[:-1]
353                 # for n_head in range(ram[0].size(1)):
354                 # filename = os.path.join(
355                 # result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
356                 # )
357                 # attention_matrices = [m[0, n_head] for m in ram]
358                 # save_attention_image(
359                 # filename,
360                 # tokens_input,
361                 # tokens_output,
362                 # attention_matrices,
363                 # k_top=10,
364                 ##min_total_attention=0.9,
365                 # token_gap=12,
366                 # layer_gap=50,
367                 # )
368                 # logger(f"wrote {filename}")
369
370
371 ######################################################################
372
373 import picoclvr
374
375
376 class PicoCLVR(Task):
377     # Make a tensor from a list of strings
378     def tensorize(self, descr):
379         token_descr = [s.strip().split(" ") for s in descr]
380         l = max([len(s) for s in token_descr])
381         token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
382         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
383         return torch.tensor(id_descr, device=self.device)
384
385     # Make a list of strings from a tensor
386     def detensorize(self, x):
387         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
388
389     # trim all the tensors in the tuple z to remove as much token from
390     # left and right in the first tensor. If z is a tuple, all its
391     # elements are trimed according to the triming for the first
392     def trim(self, z, token="<nul>"):
393         n = self.token2id[token]
394         if type(z) == tuple:
395             x = z[0]
396             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
397             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
398             return tuple([t[:, a:b] for t in z])
399         else:
400             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
401             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
402             return z[:, a:b]
403
404     ######################
405
406     def __init__(
407         self,
408         nb_train_samples,
409         nb_test_samples,
410         batch_size,
411         height,
412         width,
413         nb_colors=5,
414         logger=None,
415         device=torch.device("cpu"),
416         pruner_train=None,
417         pruner_eval=None,
418     ):
419         super().__init__()
420
421         def generate_descr(nb, cache_suffix, pruner):
422             return picoclvr.generate(
423                 nb,
424                 height=self.height,
425                 width=self.width,
426                 nb_colors=nb_colors,
427                 pruner=pruner,
428             )
429
430         self.height = height
431         self.width = width
432         self.batch_size = batch_size
433         self.device = device
434         self.pruner_train = pruner_train
435         self.pruner_eval = pruner_eval
436
437         if logger is not None:
438             logger(
439                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
440             )
441
442         self.train_descr = generate_descr(
443             nb_train_samples, "train", pruner=self.pruner_train
444         )
445         self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)
446
447         # Build the tokenizer
448         tokens = {"<nul>", "<img>"}
449         for d in [self.train_descr, self.test_descr]:
450             for s in d:
451                 for t in s.strip().split(" "):
452                     tokens.add(t)
453         # make this set a sorted list to get the same tensors given
454         # the same descr
455         tokens = list(tokens)
456         tokens.sort()
457         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
458         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
459         self.t_img, self.t_nul = self.token2id["<img>"], self.token2id["<nul>"]
460
461         # Tokenize the train and test sets
462         self.train_input = self.tensorize(self.train_descr)
463         self.test_input = self.tensorize(self.test_descr)
464
465     def batches(self, split="train"):
466         assert split in {"train", "test"}
467         input = self.train_input if split == "train" else self.test_input
468         for batch in tqdm.tqdm(
469             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
470         ):
471             yield self.trim(batch)
472
473     def vocabulary_size(self):
474         return len(self.token2id)
475
476     def compute_missing_properties(
477         self, n_epoch, model, logger, deterministic_synthesis, pruner=None
478     ):
479         acc_nb_requested_properties = []
480         acc_nb_missing_properties = []
481         acc_nb_results = 0
482
483         for input in tqdm.tqdm(
484             self.test_input.split(self.batch_size),
485             dynamic_ncols=True,
486             desc=f"test-properties",
487         ):
488             result = input.clone()
489             ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1)
490             result = (1 - ar_mask) * result + ar_mask * self.t_nul
491             masked_inplace_autoregression(
492                 model,
493                 self.batch_size,
494                 result,
495                 ar_mask,
496                 deterministic_synthesis,
497                 progress_bar_desc=None,
498                 device=self.device,
499             )
500
501             result_descr = self.detensorize(result)
502             np = picoclvr.nb_properties(
503                 result_descr,
504                 height=self.height,
505                 width=self.width,
506                 pruner=pruner,
507             )
508             nb_requested_properties, _, nb_missing_properties = zip(*np)
509             acc_nb_requested_properties += nb_requested_properties
510             acc_nb_missing_properties += nb_missing_properties
511             acc_nb_results += len(result_descr)
512
513         nb_requested_properties = sum(acc_nb_requested_properties)
514         nb_missing_properties = sum(acc_nb_missing_properties)
515
516         prefix = "" if pruner is None else "pruned_"
517         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
518         logger(
519             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
520         )
521         logger(
522             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
523         )
524
525         logger(
526             f"main_test_accuracy {n_epoch} {1-nb_missing_properties/nb_requested_properties}"
527         )
528
529     ######################################################################
530
531     def produce_results(
532         self, n_epoch, model, result_dir, logger, deterministic_synthesis
533     ):
534         self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis)
535
536         if self.pruner_eval is not None:
537             self.compute_missing_properties(n_epoch, model, self.pruner_eval)
538
539         nb_tokens_to_generate = self.height * self.width + 3
540         result_descr = []
541         nb_per_primer = 8
542         primer = []
543
544         for primer_descr in [
545             "red above green <sep> green top <sep> blue right of red",
546             "there is red <sep> there is yellow <sep> there is blue",
547             "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
548             "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
549         ]:
550             primer += [primer_descr + " <img>"] * nb_per_primer
551
552         result = self.tensorize(primer)
553         fill = result.new_full(
554             result.size()[:-1] + (self.height * self.width + 1,), self.t_nul
555         )
556         result = torch.cat((result, fill), 1)
557         ar_mask = (result == self.t_nul).long()
558         masked_inplace_autoregression(
559             model,
560             self.batch_size,
561             result,
562             ar_mask,
563             deterministic_synthesis,
564             device=self.device,
565         )
566         result_descr = self.detensorize(result)
567
568         np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
569
570         acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
571         acc_nb_results = len(result_descr)
572
573         nb_requested_properties = sum(acc_nb_requested_properties)
574         nb_missing_properties = sum(acc_nb_missing_properties)
575
576         prefix = "demo_"
577         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
578         logger(
579             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
580         )
581         logger(
582             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
583         )
584
585         img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)
586
587         if img.dim() == 5:
588             if img.size(1) == 1:
589                 img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
590             else:
591                 img = torch.cat(
592                     [
593                         torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
594                         for x in img
595                     ],
596                     0,
597                 )
598
599         image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png")
600         torchvision.utils.save_image(
601             img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
602         )
603         logger(f"wrote {image_name}")
604
605
606 ######################################################################
607
608
609 class MNIST(Task):
610     def __init__(
611         self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
612     ):
613         super().__init__()
614
615         self.nb_train_samples = (nb_train_samples,)
616         self.nb_test_samples = (nb_test_samples,)
617         self.batch_size = batch_size
618         self.device = device
619         data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True)
620         self.train_input = data_set.data[:nb_train_samples].view(-1, 28 * 28).long()
621         data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
622         self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
623
624     def batches(self, split="train", nb_to_use=-1, desc=None):
625         assert split in {"train", "test"}
626         input = self.train_input if split == "train" else self.test_input
627         if nb_to_use > 0:
628             input = input[:nb_to_use]
629         if desc is None:
630             desc = f"epoch-{split}"
631         for batch in tqdm.tqdm(
632             input.split(self.batch_size), dynamic_ncols=True, desc=desc
633         ):
634             yield batch
635
636     def vocabulary_size(self):
637         return 256
638
639     def produce_results(
640         self, n_epoch, model, result_dir, logger, deterministic_synthesis
641     ):
642         results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
643         ar_mask = torch.full_like(results, 1)
644         masked_inplace_autoregression(
645             model,
646             self.batch_size,
647             results,
648             ar_mask,
649             deterministic_synthesis,
650             device=self.device,
651         )
652         image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
653         torchvision.utils.save_image(
654             1 - results.reshape(-1, 1, 28, 28) / 255.0,
655             image_name,
656             nrow=16,
657             pad_value=0.8,
658         )
659         logger(f"wrote {image_name}")
660
661
662 ######################################################################
663
664 import maze
665
666
667 class Maze(Task):
668     def map2seq(self, *m):
669         return torch.cat([x.flatten(1) for x in m], 1)
670
671     def seq2map(self, s):
672         s = s.reshape(s.size(0), -1, self.height, self.width)
673         return (s[:, k] for k in range(s.size(1)))
674
675     def __init__(
676         self,
677         nb_train_samples,
678         nb_test_samples,
679         batch_size,
680         height,
681         width,
682         nb_walls,
683         device=torch.device("cpu"),
684     ):
685         super().__init__()
686
687         self.batch_size = batch_size
688         self.height = height
689         self.width = width
690         self.device = device
691
692         train_mazes, train_paths, _ = maze.create_maze_data(
693             nb_train_samples,
694             height=height,
695             width=width,
696             nb_walls=nb_walls,
697             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
698         )
699         self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
700
701         test_mazes, test_paths, _ = maze.create_maze_data(
702             nb_test_samples,
703             height=height,
704             width=width,
705             nb_walls=nb_walls,
706             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
707         )
708         self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
709
710         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
711
712     def batches(self, split="train", nb_to_use=-1, desc=None):
713         assert split in {"train", "test"}
714         input = self.train_input if split == "train" else self.test_input
715         if nb_to_use > 0:
716             input = input[:nb_to_use]
717         if desc is None:
718             desc = f"epoch-{split}"
719         for batch in tqdm.tqdm(
720             input.split(self.batch_size), dynamic_ncols=True, desc=desc
721         ):
722             yield batch
723
724     def vocabulary_size(self):
725         return self.nb_codes
726
727     def compute_error(
728         self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
729     ):
730         nb_total, nb_correct = 0, 0
731         count = torch.zeros(
732             self.width * self.height,
733             self.width * self.height,
734             device=self.device,
735             dtype=torch.int64,
736         )
737
738         for input in self.batches(split, nb_to_use):
739             result = input.clone()
740             ar_mask = result.new_zeros(result.size())
741             ar_mask[:, self.height * self.width :] = 1
742             result *= 1 - ar_mask
743             masked_inplace_autoregression(
744                 model,
745                 self.batch_size,
746                 result,
747                 ar_mask,
748                 deterministic_synthesis,
749                 progress_bar_desc=None,
750                 device=self.device,
751             )
752             mazes, paths = self.seq2map(result)
753             path_correctness = maze.path_correctness(mazes, paths)
754             nb_correct += path_correctness.long().sum()
755             nb_total += mazes.size(0)
756
757             optimal_path_lengths = (
758                 (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
759             )
760             predicted_path_lengths = (
761                 (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
762             )
763             optimal_path_lengths = optimal_path_lengths[path_correctness]
764             predicted_path_lengths = predicted_path_lengths[path_correctness]
765             count[optimal_path_lengths, predicted_path_lengths] += 1
766
767         if count.max() == 0:
768             count = None
769         else:
770             count = count[
771                 : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
772             ]
773
774         return nb_total, nb_correct, count
775
776     def produce_results(
777         self, n_epoch, model, result_dir, logger, deterministic_synthesis
778     ):
779         train_nb_total, train_nb_correct, count = self.compute_error(
780             model,
781             "train",
782             nb_to_use=1000,
783             deterministic_synthesis=deterministic_synthesis,
784         )
785         logger(
786             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
787         )
788
789         test_nb_total, test_nb_correct, count = self.compute_error(
790             model,
791             "test",
792             nb_to_use=1000,
793             deterministic_synthesis=deterministic_synthesis,
794         )
795         logger(
796             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
797         )
798
799         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
800
801         if count is not None:
802             proportion_optimal = count.diagonal().sum().float() / count.sum()
803             logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
804             with open(
805                 os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
806             ) as f:
807                 for i in range(count.size(0)):
808                     for j in range(count.size(1)):
809                         eol = " " if j < count.size(1) - 1 else "\n"
810                         f.write(f"{count[i,j]}{eol}")
811
812         input = self.test_input[:48]
813         result = input.clone()
814         ar_mask = result.new_zeros(result.size())
815         ar_mask[:, self.height * self.width :] = 1
816         result *= 1 - ar_mask
817         masked_inplace_autoregression(
818             model,
819             self.batch_size,
820             result,
821             ar_mask,
822             deterministic_synthesis,
823             device=self.device,
824         )
825
826         mazes, paths = self.seq2map(input)
827         _, predicted_paths = self.seq2map(result)
828
829         filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png")
830         maze.save_image(
831             filename,
832             mazes=mazes,
833             target_paths=paths,
834             predicted_paths=predicted_paths,
835             path_correct=maze.path_correctness(mazes, predicted_paths),
836             path_optimal=maze.path_optimality(paths, predicted_paths),
837         )
838         logger(f"wrote {filename}")
839
840
841 ######################################################################
842
843
844 import snake
845
846
847 class Snake(Task):
848     def __init__(
849         self,
850         nb_train_samples,
851         nb_test_samples,
852         batch_size,
853         height,
854         width,
855         nb_colors,
856         length,
857         prompt_length,
858         device=torch.device("cpu"),
859     ):
860         super().__init__()
861
862         self.batch_size = batch_size
863         self.height = height
864         self.width = width
865         self.device = device
866         self.prompt_length = prompt_length
867
868         self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
869             nb_train_samples,
870             height,
871             width,
872             nb_colors,
873             length,
874             prompt_length,
875             self.device,
876         )
877         self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
878             nb_test_samples,
879             height,
880             width,
881             nb_colors,
882             length,
883             prompt_length,
884             self.device,
885         )
886
887         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
888
889     def batches(self, split="train", nb_to_use=-1, desc=None):
890         assert split in {"train", "test"}
891         input = self.train_input if split == "train" else self.test_input
892         if nb_to_use > 0:
893             input = input[:nb_to_use]
894         if desc is None:
895             desc = f"epoch-{split}"
896         for batch in tqdm.tqdm(
897             input.split(self.batch_size), dynamic_ncols=True, desc=desc
898         ):
899             yield batch
900
901     def vocabulary_size(self):
902         return self.nb_codes
903
904     def produce_results(
905         self, n_epoch, model, result_dir, logger, deterministic_synthesis
906     ):
907         def compute_nb_correct(input, prior_visits):
908             result = input.clone()
909             i = torch.arange(result.size(1), device=result.device)[None, :]
910             ar_mask = (
911                 torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
912                 .long()
913                 .expand_as(result)
914             )
915             result *= 1 - ar_mask
916
917             masked_inplace_autoregression(
918                 model,
919                 self.batch_size,
920                 result,
921                 ar_mask,
922                 deterministic_synthesis,
923                 device=self.device,
924             )
925
926             nb_total = ((prior_visits > 0) * ar_mask).sum()
927
928             nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum()
929
930             return nb_total, nb_correct
931
932         test_nb_total, test_nb_correct = compute_nb_correct(
933             self.test_input[:1000], self.test_prior_visits[:1000]
934         )
935
936         logger(
937             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
938         )
939
940         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
941
942
943 ######################################################################
944
945
946 import stack
947
948
949 class Stack(Task):
950     def __init__(
951         self,
952         nb_train_samples,
953         nb_test_samples,
954         batch_size,
955         logger,
956         nb_steps,
957         nb_stacks,
958         nb_digits,
959         fraction_values_for_train=None,
960         device=torch.device("cpu"),
961     ):
962         super().__init__()
963
964         self.batch_size = batch_size
965         self.nb_steps = nb_steps
966         self.nb_stacks = nb_stacks
967         self.nb_digits = nb_digits
968         self.device = device
969
970         if fraction_values_for_train is None:
971             values_for_train = None
972             values_for_test = None
973         else:
974             all = torch.randperm(10**nb_digits)
975             nb_for_train = int(all.size(0) * fraction_values_for_train)
976             values_for_train = all[:nb_for_train]
977             values_for_test = all[nb_for_train:]
978
979         self.train_input, self.train_stack_counts = stack.generate_sequences(
980             nb_train_samples,
981             nb_steps,
982             nb_stacks,
983             nb_digits,
984             values_for_train,
985             self.device,
986         )
987
988         self.test_input, self.test_stack_counts = stack.generate_sequences(
989             nb_test_samples,
990             nb_steps,
991             nb_stacks,
992             nb_digits,
993             values_for_test,
994             self.device,
995         )
996
997         i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
998         counts = self.test_stack_counts.flatten()[i.flatten()]
999         counts = F.one_hot(counts).sum(0)
1000         logger(f"test_pop_stack_counts {counts}")
1001
1002         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1003
1004     def batches(self, split="train", nb_to_use=-1, desc=None):
1005         assert split in {"train", "test"}
1006         input = self.train_input if split == "train" else self.test_input
1007         if nb_to_use > 0:
1008             input = input[:nb_to_use]
1009         if desc is None:
1010             desc = f"epoch-{split}"
1011         for batch in tqdm.tqdm(
1012             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1013         ):
1014             yield batch
1015
1016     def vocabulary_size(self):
1017         return self.nb_codes
1018
1019     def produce_results(
1020         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1021     ):
1022         def compute_nb_correct(input):
1023             result = input.clone()
1024             stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
1025             ar_mask = (result != input).long()
1026             masked_inplace_autoregression(
1027                 model,
1028                 self.batch_size,
1029                 result,
1030                 ar_mask,
1031                 deterministic_synthesis,
1032                 device=self.device,
1033             )
1034
1035             errors = ((result != input).long() * ar_mask).reshape(
1036                 -1, 1 + self.nb_digits
1037             )
1038             ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)
1039
1040             nb_total = ar_mask.max(1).values.sum()
1041             nb_correct = nb_total - errors.max(1).values.sum()
1042
1043             return nb_total, nb_correct
1044
1045         test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
1046
1047         logger(
1048             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1049         )
1050
1051         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
1052
1053         ##############################################################
1054         # Log a few generated sequences
1055         input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
1056         result = input.clone()
1057         stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
1058         ar_mask = (result != input).long()
1059
1060         # for n in range(result.size(0)):
1061         # logger(
1062         # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
1063         # )
1064
1065         masked_inplace_autoregression(
1066             model,
1067             self.batch_size,
1068             result,
1069             ar_mask,
1070             deterministic_synthesis,
1071             device=self.device,
1072         )
1073
1074         for n in range(result.size(0)):
1075             logger(
1076                 f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
1077             )
1078         ##############################################################
1079
1080
1081 ######################################################################
1082
1083 import rpl
1084
1085
1086 class RPL(Task):
1087     def tensorize(self, sequences):
1088         len_max = max([len(x) for x in sequences])
1089         return torch.cat(
1090             [
1091                 torch.tensor(
1092                     [
1093                         [
1094                             self.token2id[str(c)]
1095                             for c in s + ["<nul>"] * (len_max - len(s))
1096                         ]
1097                         for s in sequences
1098                     ]
1099                 )
1100             ],
1101             0,
1102         )
1103
1104     def seq2str(self, seq):
1105         return " ".join([self.id2token[i] for i in seq])
1106
1107     def __init__(
1108         self,
1109         nb_train_samples,
1110         nb_test_samples,
1111         batch_size,
1112         nb_starting_values=3,
1113         max_input=9,
1114         prog_len=6,
1115         nb_runs=5,
1116         no_prog=False,
1117         logger=None,
1118         device=torch.device("cpu"),
1119     ):
1120         super().__init__()
1121
1122         self.batch_size = batch_size
1123         self.device = device
1124         self.no_prog = no_prog
1125
1126         train_sequences = [
1127             rpl.generate(
1128                 nb_starting_values=nb_starting_values,
1129                 nb_result_values_max=4 * nb_starting_values,
1130                 max_input=max_input,
1131                 prog_len=prog_len,
1132                 nb_runs=nb_runs,
1133             )
1134             for _ in tqdm.tqdm(range(nb_train_samples), desc="train-data")
1135         ]
1136
1137         test_sequences = [
1138             rpl.generate(
1139                 nb_starting_values=nb_starting_values,
1140                 nb_result_values_max=4 * nb_starting_values,
1141                 max_input=max_input,
1142                 prog_len=prog_len,
1143                 nb_runs=nb_runs,
1144             )
1145             for _ in tqdm.tqdm(range(nb_test_samples), desc="test-data")
1146         ]
1147
1148         symbols = list(
1149             set(["<nul>"] + [x for l in train_sequences + test_sequences for x in l])
1150         )
1151         val_max = max([x if type(x) is int else 0 for x in symbols])
1152         symbols = list(filter(lambda x: type(x) is str, symbols))
1153         symbols.sort()
1154         symbols += [str(n) for n in range(val_max + 1)]
1155         self.token2id = dict([(c, n) for n, c in enumerate(symbols)])
1156         self.id2token = dict([(n, c) for c, n in self.token2id.items()])
1157
1158         self.t_nul = self.token2id["<nul>"]
1159         self.t_input = self.token2id["<in>"]
1160         self.t_output = self.token2id["<out>"]
1161         self.t_prog = self.token2id["<prg>"]
1162         self.t_end = self.token2id["<end>"]
1163
1164         self.train_input = self.tensorize(train_sequences)
1165         self.test_input = self.tensorize(test_sequences)
1166
1167         if no_prog:
1168             # Excise the program from every train and test example
1169             k = torch.arange(self.train_input.size(1), device=self.train_input.device)[
1170                 None, :
1171             ]
1172             p = (
1173                 ((self.train_input == self.t_prog).long() * k)
1174                 .max(1, keepdim=True)
1175                 .values
1176             )
1177             self.train_input = (
1178                 self.train_input * (k <= p).long()
1179                 + self.t_end * (k == p + 1).long()
1180                 + self.t_nul * (k > p + 1).long()
1181             )
1182             k = torch.arange(self.test_input.size(1), device=self.test_input.device)[
1183                 None, :
1184             ]
1185             p = (
1186                 ((self.test_input == self.t_prog).long() * k)
1187                 .max(1, keepdim=True)
1188                 .values
1189             )
1190             self.test_input = (
1191                 self.test_input * (k <= p).long()
1192                 + self.t_end * (k == p + 1).long()
1193                 + self.t_nul * (k > p + 1).long()
1194             )
1195
1196         if logger is not None:
1197             logger(f"value_max {val_max}")
1198             for x in self.train_input[:25]:
1199                 end = (x != self.t_nul).nonzero().max().item() + 1
1200                 seq = [self.id2token[i.item()] for i in x[:end]]
1201                 s = " ".join(seq)
1202                 logger(f"example_seq {s}")
1203
1204         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1205
1206     def batches(self, split="train", nb_to_use=-1, desc=None):
1207         assert split in {"train", "test"}
1208         input = self.train_input if split == "train" else self.test_input
1209         if nb_to_use > 0:
1210             input = input[:nb_to_use]
1211         if desc is None:
1212             desc = f"epoch-{split}"
1213         for batch in tqdm.tqdm(
1214             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1215         ):
1216             last = (batch != self.t_nul).max(0).values.nonzero().max() + 3
1217             batch = batch[:, :last].to(self.device)
1218             yield batch
1219
1220     def vocabulary_size(self):
1221         return self.nb_codes
1222
1223     def produce_results(
1224         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1225     ):
1226         # --------------------------------------------------------------------
1227         def compute_nb_errors_prog(input, nb_to_log=0):
1228             result = input.clone()
1229             s = (result == self.t_prog).long()
1230             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1231             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1232
1233             masked_inplace_autoregression(
1234                 model,
1235                 self.batch_size,
1236                 result,
1237                 ar_mask,
1238                 deterministic_synthesis,
1239                 device=self.device,
1240             )
1241
1242             sum_nb_total, sum_nb_errors = 0, 0
1243             for one_input, one_result in zip(input, result):
1244                 seq = [self.id2token[i.item()] for i in one_result]
1245                 nb_total, nb_errors, prog, stacks = rpl.compute_nb_errors(seq)
1246                 sum_nb_total += 1
1247                 sum_nb_errors += 0 if nb_errors == 0 else 1
1248                 if nb_to_log > 0:
1249                     gt_seq = [self.id2token[i.item()] for i in one_input]
1250                     _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq)
1251                     gt_prog = " ".join([str(x) for x in gt_prog])
1252                     prog = " ".join([str(x) for x in prog])
1253                     comment = "*" if nb_errors == 0 else "-"
1254                     logger(f"{comment} PROG [{gt_prog}] PREDICTED [{prog}]")
1255                     for start_stack, target_stack, result_stack, correct in stacks:
1256                         comment = "*" if correct else "-"
1257                         start_stack = " ".join([str(x) for x in start_stack])
1258                         target_stack = " ".join([str(x) for x in target_stack])
1259                         result_stack = " ".join([str(x) for x in result_stack])
1260                         logger(
1261                             f"  {comment} [{start_stack}] -> [{target_stack}] PREDICTED [{result_stack}]"
1262                         )
1263                     nb_to_log -= 1
1264
1265             return sum_nb_total, sum_nb_errors
1266
1267         # --------------------------------------------------------------------
1268         def compute_nb_errors_output(input, nb_to_log=0):
1269             result = input.clone()
1270             k = torch.arange(result.size(1), device=result.device)[None, :]
1271             last_output_idx = (
1272                 ((result == self.t_output) * k).max(dim=1, keepdim=True).values
1273             )
1274             first_prog_idx = (
1275                 ((result == self.t_prog) * k).max(dim=1, keepdim=True).values
1276             )
1277             ar_mask = (k > last_output_idx).long() * (k < first_prog_idx).long()
1278             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1279
1280             masked_inplace_autoregression(
1281                 model,
1282                 self.batch_size,
1283                 result,
1284                 ar_mask,
1285                 deterministic_synthesis,
1286                 device=self.device,
1287             )
1288
1289             sum_nb_total, sum_nb_errors = 0, 0
1290             for one_input, one_result, i, j in zip(
1291                 input, result, last_output_idx, first_prog_idx
1292             ):
1293                 seq = [self.id2token[i.item()] for i in one_result]
1294                 sum_nb_total += 1
1295                 correct = (one_input - one_result).abs().max() == 0
1296                 sum_nb_errors += 0 if correct else 1
1297                 if nb_to_log > 0:
1298                     result_stack = [
1299                         self.id2token[i.item()] for i in one_result[i : j + 1]
1300                     ]
1301                     target_stack = [
1302                         self.id2token[i.item()] for i in one_input[i : j + 1]
1303                     ]
1304                     comment = "*" if correct else "-"
1305                     result_stack = " ".join([str(x) for x in result_stack])
1306                     target_stack = " ".join([str(x) for x in target_stack])
1307                     logger(
1308                         f"output_test {comment} [{target_stack}] PREDICTED [{result_stack}]"
1309                     )
1310                     nb_to_log -= 1
1311
1312             return sum_nb_total, sum_nb_errors
1313
1314         # --------------------------------------------------------------------
1315
1316         if not self.no_prog:
1317             test_nb_total, test_nb_errors = compute_nb_errors_prog(
1318                 self.test_input[:1000].to(self.device), nb_to_log=10
1319             )
1320
1321             logger(
1322                 f"accuracy_prog_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1323             )
1324
1325             logger(f"main_test_accuracy {n_epoch} {1-test_nb_errors/test_nb_total}")
1326
1327         test_nb_total, test_nb_errors = compute_nb_errors_output(
1328             self.test_input[:1000].to(self.device), nb_to_log=10
1329         )
1330
1331         logger(
1332             f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1333         )
1334
1335         if save_attention_image is None:
1336             logger("no save_attention_image (is pycairo installed?)")
1337         else:
1338             ns = torch.randint(self.test_input.size(0), (1,)).item()
1339             input = self.test_input[ns : ns + 1].clone()
1340             last = (input != self.t_nul).max(0).values.nonzero().max() + 3
1341             input = input[:, :last].to(self.device)
1342
1343             with torch.autograd.no_grad():
1344                 t = model.training
1345                 model.eval()
1346                 model.record_attention(True)
1347                 model(BracketedSequence(input))
1348                 model.train(t)
1349                 ram = model.retrieve_attention()
1350                 model.record_attention(False)
1351
1352             tokens_output = [self.id2token[i.item()] for i in input[0]]
1353             tokens_input = ["n/a"] + tokens_output[:-1]
1354             for n_head in range(ram[0].size(1)):
1355                 filename = os.path.join(
1356                     result_dir, f"rpl_attention_{n_epoch}_h{n_head}.pdf"
1357                 )
1358                 attention_matrices = [m[0, n_head] for m in ram]
1359                 save_attention_image(
1360                     filename,
1361                     tokens_input,
1362                     tokens_output,
1363                     attention_matrices,
1364                     k_top=10,
1365                     # min_total_attention=0.9,
1366                     token_gap=12,
1367                     layer_gap=50,
1368                 )
1369                 logger(f"wrote {filename}")
1370
1371
1372 ######################################################################
1373
1374
1375 import expr
1376
1377
1378 class Expr(Task):
1379     def tensorize(self, sequences):
1380         len_max = max([len(x) for x in sequences])
1381         return torch.cat(
1382             [
1383                 torch.tensor(
1384                     [
1385                         [self.char2id[c] for c in s + "#" * (len_max - len(s))]
1386                         for s in sequences
1387                     ]
1388                 )
1389             ],
1390             0,
1391         ).to(self.device)
1392
1393     def __init__(
1394         self,
1395         nb_train_samples,
1396         nb_test_samples,
1397         nb_variables,
1398         sequence_length,
1399         operand_max,
1400         result_max,
1401         batch_size,
1402         device=torch.device("cpu"),
1403     ):
1404         super().__init__()
1405
1406         self.batch_size = batch_size
1407         self.device = device
1408
1409         train_sequences = expr.generate_sequences(
1410             nb_train_samples,
1411             nb_variables=nb_variables,
1412             length=sequence_length,
1413             operand_max=operand_max,
1414             result_max=result_max,
1415         )
1416
1417         test_sequences = expr.generate_sequences(
1418             nb_test_samples,
1419             nb_variables=nb_variables,
1420             length=sequence_length,
1421             operand_max=operand_max,
1422             result_max=result_max,
1423         )
1424
1425         symbols = list(set("#" + "".join(train_sequences + test_sequences)))
1426         symbols.sort()
1427
1428         self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
1429         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
1430
1431         self.filler, self.space = self.char2id["#"], self.char2id[" "]
1432
1433         self.train_input = self.tensorize(train_sequences)
1434         self.test_input = self.tensorize(test_sequences)
1435
1436         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1437
1438     def batches(self, split="train", nb_to_use=-1, desc=None):
1439         assert split in {"train", "test"}
1440         input = self.train_input if split == "train" else self.test_input
1441         if nb_to_use > 0:
1442             input = input[:nb_to_use]
1443         if desc is None:
1444             desc = f"epoch-{split}"
1445         for batch in tqdm.tqdm(
1446             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1447         ):
1448             last = (batch != self.filler).max(0).values.nonzero().max() + 3
1449             batch = batch[:, :last]
1450             yield batch
1451
1452     def vocabulary_size(self):
1453         return self.nb_codes
1454
1455     def seq2str(self, s):
1456         return "".join([self.id2char[k.item()] for k in s])
1457
1458     def produce_results(
1459         self,
1460         n_epoch,
1461         model,
1462         result_dir,
1463         logger,
1464         deterministic_synthesis,
1465         input_file=None,
1466     ):
1467         def compute_nb_correct(input):
1468             result = input.clone()
1469             s = (result == self.space).long()
1470             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1471             result = (1 - ar_mask) * result + ar_mask * self.filler
1472             masked_inplace_autoregression(
1473                 model,
1474                 self.batch_size,
1475                 result,
1476                 ar_mask,
1477                 deterministic_synthesis,
1478                 device=self.device,
1479             )
1480
1481             nb_total = input.size(0)
1482             nb_correct = (input == result).long().min(1).values.sum()
1483
1484             #######################################################################
1485             # Comput predicted vs. true variable values
1486
1487             nb_delta = torch.zeros(5, dtype=torch.int64)
1488             nb_missed = 0
1489
1490             values_input = expr.extract_results([self.seq2str(s) for s in input])
1491             values_result = expr.extract_results([self.seq2str(s) for s in result])
1492
1493             filename = os.path.join(result_dir, f"expr_result_{n_epoch:04d}.txt")
1494
1495             with open(filename, "w") as f:
1496                 for i, r in zip(values_input, values_result):
1497                     for n, vi in i.items():
1498                         vr = r.get(n)
1499                         f.write(f"{vi} {-1 if vr is None else vr}\n")
1500
1501                         if vr is None or vr < 0:
1502                             nb_missed += 1
1503                         else:
1504                             d = abs(vr - vi)
1505                             if d >= nb_delta.size(0):
1506                                 nb_missed += 1
1507                             else:
1508                                 nb_delta[d] += 1
1509
1510             ######################################################################
1511
1512             return nb_total, nb_correct, nb_delta, nb_missed
1513
1514         (
1515             test_nb_total,
1516             test_nb_correct,
1517             test_nb_delta,
1518             test_nb_missed,
1519         ) = compute_nb_correct(self.test_input[:10000])
1520
1521         logger(
1522             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1523         )
1524
1525         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
1526
1527         nb_total = test_nb_delta.sum() + test_nb_missed
1528         for d in range(test_nb_delta.size(0)):
1529             logger(
1530                 f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%"
1531             )
1532         logger(
1533             f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%"
1534         )
1535
1536         ##############################################################
1537         # Log a few generated sequences
1538         if input_file is None:
1539             input = self.test_input[:10]
1540         else:
1541             with open(input_file, "r") as f:
1542                 sequences = [e.strip() for e in f.readlines()]
1543                 sequences = [s + " " + "#" * 50 for s in sequences]
1544                 input = self.tensorize(sequences)
1545
1546         result = input.clone()
1547         s = (result == self.space).long()
1548         ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1549         result = (1 - ar_mask) * result + ar_mask * self.filler
1550
1551         for n in range(result.size(0)):
1552             logger(f"test_before {self.seq2str(result[n])}")
1553
1554         masked_inplace_autoregression(
1555             model,
1556             self.batch_size,
1557             result,
1558             ar_mask,
1559             deterministic_synthesis,
1560             device=self.device,
1561         )
1562
1563         correct = (1 - ar_mask) * self.space + ar_mask * input
1564         for n in range(result.size(0)):
1565             comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
1566             logger(f"test_after  {self.seq2str(result[n])} {comment}")
1567             logger(f"truth       {self.seq2str(correct[n])}")
1568         ##############################################################
1569
1570
1571 ######################################################################
1572
1573 import grid
1574
1575
1576 class Grid(Task):
1577     # Make a tensor from a list of strings
1578     def str2tensor(self, descr):
1579         token_descr = [s.strip().split(" ") for s in descr]
1580         l = max([len(s) for s in token_descr])
1581         token_descr = [s + ["#"] * (l - len(s)) for s in token_descr]
1582         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
1583         return torch.tensor(id_descr, device=self.device)
1584
1585     # Make a list of strings from a tensor
1586     def tensor2str(self, x):
1587         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
1588
1589     # trim all the tensors in the tuple z to remove as much token from
1590     # left and right in the first tensor. If z is a tuple, all its
1591     # elements are trimed according to the triming for the first
1592     def trim(self, z, token="#"):
1593         n = self.token2id[token]
1594         if type(z) == tuple:
1595             x = z[0]
1596             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
1597             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
1598             return tuple([t[:, a:b] for t in z])
1599         else:
1600             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
1601             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
1602             return z[:, a:b]
1603
1604     ######################
1605
1606     def __init__(
1607         self,
1608         nb_train_samples,
1609         nb_test_samples,
1610         batch_size,
1611         size,
1612         fraction_play=0.0,
1613         logger=None,
1614         device=torch.device("cpu"),
1615     ):
1616         super().__init__()
1617
1618         self.device = device
1619         self.batch_size = batch_size
1620         self.grid_factory = grid.GridFactory(size=size)
1621         self.fraction_play = fraction_play
1622
1623         if logger is not None:
1624             logger(
1625                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
1626             )
1627
1628         self.train_descr = self.grid_factory.generate_samples(
1629             nb=nb_train_samples,
1630             fraction_play=fraction_play,
1631             progress_bar=lambda r: tqdm.tqdm(r),
1632         )
1633
1634         self.test_descr = self.grid_factory.generate_samples(
1635             nb=nb_test_samples, fraction_play=0.0, progress_bar=lambda r: tqdm.tqdm(r)
1636         )
1637
1638         if fraction_play > 0:
1639             self.play_descr = self.grid_factory.generate_samples(
1640                 nb=25, fraction_play=1.0, progress_bar=lambda r: tqdm.tqdm(r)
1641             )
1642         else:
1643             self.play_descr = []
1644
1645         # Build the tokenizer
1646         tokens = set()
1647         for d in [self.train_descr, self.test_descr, self.play_descr]:
1648             for s in d:
1649                 for t in s.strip().split(" "):
1650                     tokens.add(t)
1651         # make this set a sorted list to get the same tensors given
1652         # the same descr
1653         tokens = list(tokens)
1654         tokens.sort()
1655         tokens = ["#"] + tokens
1656         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
1657         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
1658         self.t_nul = self.token2id["#"]
1659         self.t_true = self.token2id["true"]
1660         self.t_false = self.token2id["false"]
1661         self.t_pipe = self.token2id["|"]
1662
1663         # Tokenize the train and test sets
1664         self.train_input = self.str2tensor(self.train_descr)
1665         self.test_input = self.str2tensor(self.test_descr)
1666         self.play_input = (
1667             None if len(self.play_descr) == 0 else self.str2tensor(self.play_descr)
1668         )
1669
1670     def batches(self, split="train"):
1671         assert split in {"train", "test"}
1672         input = self.train_input if split == "train" else self.test_input
1673         for batch in tqdm.tqdm(
1674             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
1675         ):
1676             yield self.trim(batch)
1677
1678     def vocabulary_size(self):
1679         return len(self.token2id)
1680
1681     def produce_results(
1682         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1683     ):
1684         correct = self.test_input[:1000]
1685         result = correct.clone()
1686         ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long()
1687         result *= 1 - ar_mask  # paraaaaanoiaaaaaaa
1688
1689         logger(f"----------------------------------------------------------")
1690
1691         for e in self.tensor2str(result[:10]):
1692             logger(f"test_before {e}")
1693
1694         masked_inplace_autoregression(
1695             model,
1696             self.batch_size,
1697             result,
1698             ar_mask,
1699             deterministic_synthesis,
1700             device=self.device,
1701         )
1702
1703         logger(f"----------------------------------------------------------")
1704
1705         for e in self.tensor2str(result[:10]):
1706             logger(f"test_after  {e}")
1707
1708         logger(f"----------------------------------------------------------")
1709
1710         nb_total = ar_mask.sum().item()
1711         nb_correct = ((correct == result).long() * ar_mask).sum().item()
1712
1713         logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
1714         logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
1715
1716         if self.play_input is not None:
1717             result = self.play_input.clone()
1718             ar_mask = (result == self.t_pipe).long().cumsum(dim=1).clamp(max=1)
1719             result *= 1 - ar_mask  # paraaaaanoiaaaaaaa
1720
1721             logger(f"----------------------------------------------------------")
1722
1723             for e in self.tensor2str(result[:10]):
1724                 logger(f"play_before {e}")
1725
1726             masked_inplace_autoregression(
1727                 model,
1728                 self.batch_size,
1729                 result,
1730                 ar_mask,
1731                 deterministic_synthesis,
1732                 device=self.device,
1733             )
1734
1735             logger(f"----------------------------------------------------------")
1736
1737             for e in self.tensor2str(result[:10]):
1738                 logger(f"play_after  {e}")
1739
1740             logger(f"----------------------------------------------------------")
1741
1742
1743 ######################################################################
1744
1745 import qmlp
1746
1747
1748 class QMLP(Task):
1749     ######################
1750
1751     def __init__(
1752         self,
1753         nb_train_samples,
1754         nb_test_samples,
1755         batch_size,
1756         result_dir,
1757         logger=None,
1758         device=torch.device("cpu"),
1759     ):
1760         super().__init__()
1761
1762         self.device = device
1763         self.batch_size = batch_size
1764         self.nb_samples_per_mlp = 256
1765
1766         if logger is not None:
1767             logger(
1768                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
1769             )
1770
1771         seq, q_test_set, test_error = qmlp.generate_sequence_and_test_set(
1772             nb_mlps=nb_train_samples + nb_test_samples,
1773             nb_samples=self.nb_samples_per_mlp,
1774             device=self.device,
1775             batch_size=64,
1776             nb_epochs=250,
1777             nb_mlps_per_batch=1024,
1778         )
1779
1780         self.train_input = seq[:nb_train_samples]
1781         self.train_q_test_set = q_test_set[:nb_train_samples]
1782         self.train_ref_test_errors = test_error[:nb_train_samples]
1783         self.test_input = seq[nb_train_samples:]
1784         self.test_q_test_set = q_test_set[nb_train_samples:]
1785         self.test_ref_test_errors = test_error[nb_train_samples:]
1786
1787         filename = os.path.join(result_dir, f"train_errors_ref.dat")
1788         with open(filename, "w") as f:
1789             for e in self.train_ref_test_errors:
1790                 f.write(f"{e}\n")
1791
1792         filename = os.path.join(result_dir, f"test_errors_ref.dat")
1793         with open(filename, "w") as f:
1794             for e in self.test_ref_test_errors:
1795                 f.write(f"{e}\n")
1796
1797         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1798
1799     def batches(self, split="train"):
1800         assert split in {"train", "test"}
1801         input = self.train_input if split == "train" else self.test_input
1802         for batch in tqdm.tqdm(
1803             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
1804         ):
1805             yield batch
1806
1807     def vocabulary_size(self):
1808         return self.nb_codes
1809
1810     def produce_results(
1811         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1812     ):
1813         correct = self.test_input[:1000]
1814         result = correct.clone()
1815         ar_mask = (
1816             torch.arange(result.size(1), device=result.device)
1817             > self.nb_samples_per_mlp * 3 + 1
1818         ).long()[None, :]
1819         ar_mask = ar_mask.expand_as(result)
1820         result *= 1 - ar_mask  # paraaaaanoiaaaaaaa
1821
1822         masked_inplace_autoregression(
1823             model,
1824             self.batch_size,
1825             result,
1826             ar_mask,
1827             deterministic_synthesis,
1828             device=self.device,
1829         )
1830
1831         q_train_set = result[:, : self.nb_samples_per_mlp * 3]
1832         q_params = result[:, self.nb_samples_per_mlp * 3 + 1 :]
1833         error_test = qmlp.evaluate_q_params(q_params, self.test_q_test_set)
1834
1835         filename = os.path.join(result_dir, f"test_errors_{n_epoch:04d}.dat")
1836         with open(filename, "w") as f:
1837             for e in error_test:
1838                 f.write(f"{e}\n")
1839
1840
1841 ######################################################################