1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, os, tqdm
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 from mygpt import BracketedSequence
16
17 # from graph import save_attention_image
18 save_attention_image = None
19
20 ######################################################################
21
22
23 def masked_inplace_autoregression(
24     model,
25     batch_size,
26     input,
27     ar_mask,
28     deterministic_synthesis,
29     forbidden_tokens=None,
30     progress_bar_desc="autoregression",
31     device=torch.device("cpu"),
32 ):
33     assert input.size() == ar_mask.size()
34
35     batches = zip(input.split(batch_size), ar_mask.split(batch_size))
36
37     if progress_bar_desc is not None:
38         batches = tqdm.tqdm(
39             batches,
40             dynamic_ncols=True,
41             desc=progress_bar_desc,
42             total=(input.size(0) + batch_size - 1) // batch_size,
43         )
44
45     with torch.autograd.no_grad():
46         t = model.training
47         model.eval()
48
49         for input, ar_mask in batches:
50             model.masked_inplace_autoregression(
51                 input, ar_mask, forbidden_tokens, deterministic_synthesis
52             )
53
54         model.train(t)
55
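# Illustrative call pattern, kept as a comment only (this exact snippet is
# not part of the original code): every task below builds a `result` tensor
# that already contains the prompt, an `ar_mask` of the same shape with 1 at
# the positions to generate, and then calls
#
#     masked_inplace_autoregression(
#         model,
#         batch_size,
#         result,
#         ar_mask,
#         deterministic_synthesis,
#         device=device,
#     )
#
# after which the masked positions of `result` have been filled in place,
# batch by batch, with the model temporarily switched to eval mode.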
56
57 ######################################################################
58
59
60 class Task:
61     def batches(self, split="train"):
62         pass
63
64     def vocabulary_size(self):
65         pass
66
67     def produce_results(
68         self, n_epoch, model, result_dir, logger, deterministic_synthesis
69     ):
70         pass
71
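# A minimal sketch of what a concrete task must provide; the class below is
# hypothetical and commented out, it only illustrates the interface defined
# above:
#
#     class RandomTokensTask(Task):
#         def __init__(self, nb_samples, batch_size):
#             self.batch_size = batch_size
#             self.train_input = torch.randint(10, (nb_samples, 16))
#             self.test_input = torch.randint(10, (nb_samples, 16))
#
#         def batches(self, split="train"):
#             input = self.train_input if split == "train" else self.test_input
#             yield from input.split(self.batch_size)
#
#         def vocabulary_size(self):
#             return 10
#
#         def produce_results(
#             self, n_epoch, model, result_dir, logger, deterministic_synthesis
#         ):
#             logger(f"nothing to evaluate at epoch {n_epoch}")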
72
73 class TaskFromFile(Task):
74     def tensorize(self, pairs):
75         len_max = max([len(x[0]) for x in pairs])
76
77         input = torch.cat(
78             [
79                 torch.tensor(
80                     [
81                         [self.char2id[c] for c in s[0] + "#" * (len_max - len(s[0]))]
82                         for s in pairs
83                     ]
84                 )
85             ],
86             0,
87         ).to("cpu")
88
89         pred_mask = torch.cat(
90             [
91                 torch.tensor(
92                     [
93                         [int(c) for c in s[1] + "0" * (len_max - len(s[1]))]
94                         for s in pairs
95                     ]
96                 )
97             ],
98             0,
99         ).to("cpu")
100
101         return input, pred_mask
102
103     # Trim the tensor z (or the first element of the tuple z) by removing
104     # from both ends as many columns made only of `token` as possible; if
105     # z is a tuple, every element is trimmed according to the first one.
106     def trim(self, z, token="#"):
107         n = self.char2id[token]
108         if type(z) == tuple:
109             x = z[0]
110             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
111             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
112             return tuple([t[:, a:b] for t in z])
113         else:
114             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
115             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
116             return z[:, a:b]
117
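    # Illustrative example of trim(), with "#" mapped to id 0 as guaranteed
    # by the vocabulary construction below: given the batch
    #
    #     0 0 3 5 0
    #     0 2 4 0 0
    #
    # the first and last columns contain only the padding token and are
    # dropped, so the three middle columns are returned.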
118     def __init__(
119         self,
120         filename,
121         nb_train_samples,
122         nb_test_samples,
123         batch_size,
124         device=torch.device("cpu"),
125     ):
126         self.batch_size = batch_size
127         self.device = device
128
129         pairs = []
130         with open(filename, "r") as f:
131             for _ in range(nb_train_samples + nb_test_samples):
132                 sequence = f.readline().strip()
133                 pred_mask = f.readline().strip()
134                 assert len(sequence) == len(pred_mask)
135                 assert set(pred_mask) == {"0", "1", "2"}, f"{set(pred_mask)}"
136                 pairs.append((sequence, pred_mask))
137
138         symbols = ["#"] + sorted(set("".join([x[0] for x in pairs])) - {"#"})
139         print("SANITY", symbols)
140         self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
141         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
142
143         print(self.char2id)
144
145         self.train_input, self.train_pred_masks = self.tensorize(
146             pairs[:nb_train_samples]
147         )
148         self.test_input, self.test_pred_masks = self.tensorize(pairs[nb_train_samples:])
149
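    # The file read above interleaves, for every sample, one line with the
    # sequence and one line with a mask of the same length over {0, 1, 2},
    # e.g. (purely illustrative values):
    #
    #     ab=ba#
    #     000122
    #
    # Positions whose mask is non-zero are re-generated by the model in
    # produce_results(), and only the positions marked 2 are scored.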
150     def batches(self, split="train", nb_to_use=-1, desc=None):
151         assert split in {"train", "test"}
152         input = self.train_input if split == "train" else self.test_input
153         if nb_to_use > 0:
154             input = input[:nb_to_use]
155         if desc is None:
156             desc = f"epoch-{split}"
157         for batch in tqdm.tqdm(
158             input.split(self.batch_size), dynamic_ncols=True, desc=desc
159         ):
160             yield self.trim(batch).to(self.device)
161
162     def vocabulary_size(self):
163         return len(self.char2id)
164
165     def tensor2str(self, t):
166         print(f"{type(t)=}")
167         return ["".join([self.id2char[x.item()] for x in s]) for s in t]
168
169     def produce_results(
170         self, n_epoch, model, result_dir, logger, deterministic_synthesis
171     ):
172         correct = self.trim(self.test_input[:1000]).to(self.device)
173         result = correct.clone()
174         pred_mask = self.test_pred_masks[:1000, : result.size(1)].to(self.device)
175         ar_mask = (pred_mask > 0).long()
176         result *= 1 - ar_mask  # paraaaaanoiaaaaaaa
177
178         logger(f"----------------------------------------------------------")
179
180         for e in self.tensor2str(result[:10]):
181             logger(f"test_before {e}")
182
183         masked_inplace_autoregression(
184             model,
185             self.batch_size,
186             result,
187             ar_mask,
188             deterministic_synthesis,
189             device=self.device,
190         )
191
192         logger(f"----------------------------------------------------------")
193
194         for e, c in zip(self.tensor2str(result[:10]), self.tensor2str(correct[:10])):
195             logger(f"test_after  {e}")
196             logger(f"correct     {c}")
197
198         logger(f"----------------------------------------------------------")
199
200         err_mask = (pred_mask == 2).long()
201         nb_total = err_mask.sum().item()
202         nb_correct = ((correct == result).long() * err_mask).sum().item()
203
204         logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
205         logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
206
207
208 ####################
209
210 import problems
211
212
213 class SandBox(Task):
214     def __init__(
215         self,
216         problem,
217         nb_train_samples,
218         nb_test_samples,
219         batch_size,
220         logger=None,
221         device=torch.device("cpu"),
222         max_nb_codes=1024,
223     ):
224         super().__init__()
225
226         self.batch_size = batch_size
227         self.device = device
228         self.problem = problem
229
230         self.train_input, self.train_ar_mask = self.problem.generate_sequences(
231             nb_train_samples
232         )
233         self.test_input, self.test_ar_mask = self.problem.generate_sequences(
234             nb_test_samples
235         )
236
237         self.train_input, self.train_ar_mask = self.train_input.to(
238             device
239         ), self.train_ar_mask.to(device)
240         self.test_input, self.test_ar_mask = self.test_input.to(
241             device
242         ), self.test_ar_mask.to(device)
243
244         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
245
246         # A bit of paranoia never hurts
247         assert self.nb_codes <= max_nb_codes
248         assert self.train_input.min() >= 0
249         assert self.test_input.min() >= 0
250         assert tuple(x.item() for x in self.train_ar_mask.unique()) in {
251             (0,),
252             (1,),
253             (0, 1),
254         }
255         assert tuple(x.item() for x in self.test_ar_mask.unique()) in {
256             (0,),
257             (1,),
258             (0, 1),
259         }
260
261         if logger is not None:
262             for s, a in zip(self.train_input[:100], self.train_ar_mask[:100]):
263                 logger(f"train_sequences {self.problem.seq2str(s)}")
264                 a = "".join(["01"[x.item()] for x in a])
265                 logger(f"                {a}")
266
267     def batches(self, split="train", nb_to_use=-1, desc=None):
268         assert split in {"train", "test"}
269         input = self.train_input if split == "train" else self.test_input
270         if nb_to_use > 0:
271             input = input[:nb_to_use]
272         if desc is None:
273             desc = f"epoch-{split}"
274         for batch in tqdm.tqdm(
275             input.split(self.batch_size), dynamic_ncols=True, desc=desc
276         ):
277             yield batch
278
279     def vocabulary_size(self):
280         return self.nb_codes
281
282     def produce_results(
283         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
284     ):
285         def compute_accuracy(input, ar_mask, logger=None):
286             input, ar_mask = input[:nmax], ar_mask[:nmax]
287             result = input.clone() * (1 - ar_mask)
288
289             masked_inplace_autoregression(
290                 model,
291                 self.batch_size,
292                 result,
293                 ar_mask,
294                 deterministic_synthesis,
295                 progress_bar_desc=None,
296                 device=self.device,
297             )
298
299             log_ground_truth = ar_mask.min() == 0
300
301             if logger is not None:
302                 for sp, st in zip(result[:10], input[:10]):
303                     logger(
304                         f"test_sequences {n_epoch} prediction   {self.problem.seq2str(sp)}"
305                     )
306                     if log_ground_truth:
307                         logger(
308                             f"               {n_epoch} ground truth {self.problem.seq2str(st)}"
309                         )
310
311             nb_total, nb_correct = self.problem.compute_nb_correct(
312                 input, ar_mask, result
313             )
314
315             # nb_total = ar_mask.sum().item()
316             # nb_correct = ((result == input).long() * ar_mask).sum().item()
317
318             return nb_total, nb_correct
319
320         train_nb_total, train_nb_correct = compute_accuracy(
321             self.train_input, self.train_ar_mask
322         )
323
324         logger(
325             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
326         )
327
328         test_nb_total, test_nb_correct = compute_accuracy(
329             self.test_input, self.test_ar_mask, logger
330         )
331
332         logger(
333             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
334         )
335
336         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
337
338         if save_attention_image is not None:
339             for k in range(10):
340                 ns = torch.randint(self.test_input.size(0), (1,)).item()
341                 input = self.test_input[ns : ns + 1].clone()
342
343                 with torch.autograd.no_grad():
344                     t = model.training
345                     model.eval()
346                     # model.record_attention(True)
347                     model(BracketedSequence(input))
348                     model.train(t)
349                     # ram = model.retrieve_attention()
350                     # model.record_attention(False)
351
352                 # tokens_output = [c for c in self.problem.seq2str(input[0])]
353                 # tokens_input = ["n/a"] + tokens_output[:-1]
354                 # for n_head in range(ram[0].size(1)):
355                 # filename = os.path.join(
356                 # result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
357                 # )
358                 # attention_matrices = [m[0, n_head] for m in ram]
359                 # save_attention_image(
360                 # filename,
361                 # tokens_input,
362                 # tokens_output,
363                 # attention_matrices,
364                 # k_top=10,
365                 ##min_total_attention=0.9,
366                 # token_gap=12,
367                 # layer_gap=50,
368                 # )
369                 # logger(f"wrote {filename}")
370
371
372 ######################################################################
373
374 import picoclvr
375
376
377 class PicoCLVR(Task):
378     # Make a tensor from a list of strings
379     def tensorize(self, descr):
380         token_descr = [s.strip().split(" ") for s in descr]
381         l = max([len(s) for s in token_descr])
382         token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
383         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
384         return torch.tensor(id_descr, device=self.device)
385
386     # Make a list of strings from a tensor
387     def detensorize(self, x):
388         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
389
390     # Trim the tensor z (or the first element of the tuple z) by removing
391     # from both ends as many columns made only of `token` as possible; if
392     # z is a tuple, every element is trimmed according to the first one.
393     def trim(self, z, token="<nul>"):
394         n = self.token2id[token]
395         if type(z) == tuple:
396             x = z[0]
397             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
398             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
399             return tuple([t[:, a:b] for t in z])
400         else:
401             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
402             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
403             return z[:, a:b]
404
405     ######################
406
407     def __init__(
408         self,
409         nb_train_samples,
410         nb_test_samples,
411         batch_size,
412         height,
413         width,
414         nb_colors=5,
415         logger=None,
416         device=torch.device("cpu"),
417         pruner_train=None,
418         pruner_eval=None,
419     ):
420         super().__init__()
421
422         def generate_descr(nb, cache_suffix, pruner):
423             return picoclvr.generate(
424                 nb,
425                 height=self.height,
426                 width=self.width,
427                 nb_colors=nb_colors,
428                 pruner=pruner,
429             )
430
431         self.height = height
432         self.width = width
433         self.batch_size = batch_size
434         self.device = device
435         self.pruner_train = pruner_train
436         self.pruner_eval = pruner_eval
437
438         if logger is not None:
439             logger(
440                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
441             )
442
443         self.train_descr = generate_descr(
444             nb_train_samples, "train", pruner=self.pruner_train
445         )
446         self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)
447
448         # Build the tokenizer
449         tokens = {"<nul>", "<img>"}
450         for d in [self.train_descr, self.test_descr]:
451             for s in d:
452                 for t in s.strip().split(" "):
453                     tokens.add(t)
454         # make this set a sorted list to get the same tensors given
455         # the same descr
456         tokens = list(tokens)
457         tokens.sort()
458         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
459         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
460         self.t_img, self.t_nul = self.token2id["<img>"], self.token2id["<nul>"]
461
462         # Tokenize the train and test sets
463         self.train_input = self.tensorize(self.train_descr)
464         self.test_input = self.tensorize(self.test_descr)
465
466     def batches(self, split="train"):
467         assert split in {"train", "test"}
468         input = self.train_input if split == "train" else self.test_input
469         for batch in tqdm.tqdm(
470             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
471         ):
472             yield self.trim(batch)
473
474     def vocabulary_size(self):
475         return len(self.token2id)
476
477     def compute_missing_properties(
478         self, n_epoch, model, logger, deterministic_synthesis, pruner=None
479     ):
480         acc_nb_requested_properties = []
481         acc_nb_missing_properties = []
482         acc_nb_results = 0
483
484         for input in tqdm.tqdm(
485             self.test_input.split(self.batch_size),
486             dynamic_ncols=True,
487             desc=f"test-properties",
488         ):
489             result = input.clone()
490             ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1)
491             result = (1 - ar_mask) * result + ar_mask * self.t_nul
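            # ar_mask is 1 from the first <img> token (included) onward, so
            # the marker and the image tokens are blanked to <nul> and
            # re-generated, while the textual description before it is kept
            # as the prompt.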
492             masked_inplace_autoregression(
493                 model,
494                 self.batch_size,
495                 result,
496                 ar_mask,
497                 deterministic_synthesis,
498                 progress_bar_desc=None,
499                 device=self.device,
500             )
501
502             result_descr = self.detensorize(result)
503             np = picoclvr.nb_properties(
504                 result_descr,
505                 height=self.height,
506                 width=self.width,
507                 pruner=pruner,
508             )
509             nb_requested_properties, _, nb_missing_properties = zip(*np)
510             acc_nb_requested_properties += nb_requested_properties
511             acc_nb_missing_properties += nb_missing_properties
512             acc_nb_results += len(result_descr)
513
514         nb_requested_properties = sum(acc_nb_requested_properties)
515         nb_missing_properties = sum(acc_nb_missing_properties)
516
517         prefix = "" if pruner is None else "pruned_"
518         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
519         logger(
520             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
521         )
522         logger(
523             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
524         )
525
526         logger(
527             f"main_test_accuracy {n_epoch} {1-nb_missing_properties/nb_requested_properties}"
528         )
529
530     ######################################################################
531
532     def produce_results(
533         self, n_epoch, model, result_dir, logger, deterministic_synthesis
534     ):
535         self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis)
536
537         if self.pruner_eval is not None:
538             self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis, self.pruner_eval)
539
540         nb_tokens_to_generate = self.height * self.width + 3
541         result_descr = []
542         nb_per_primer = 8
543         primer = []
544
545         for primer_descr in [
546             "red above green <sep> green top <sep> blue right of red",
547             "there is red <sep> there is yellow <sep> there is blue",
548             "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
549             "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
550         ]:
551             primer += [primer_descr + " <img>"] * nb_per_primer
552
553         result = self.tensorize(primer)
554         fill = result.new_full(
555             result.size()[:-1] + (self.height * self.width + 1,), self.t_nul
556         )
557         result = torch.cat((result, fill), 1)
558         ar_mask = (result == self.t_nul).long()
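        # Every <nul> position, i.e. the height*width+1 appended tokens plus
        # whatever padding tensorize() added to the shorter primers, is
        # re-generated to synthesize the image part of each sequence.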
559         masked_inplace_autoregression(
560             model,
561             self.batch_size,
562             result,
563             ar_mask,
564             deterministic_synthesis,
565             device=self.device,
566         )
567         result_descr = self.detensorize(result)
568
569         np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
570
571         acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
572         acc_nb_results = len(result_descr)
573
574         nb_requested_properties = sum(acc_nb_requested_properties)
575         nb_missing_properties = sum(acc_nb_missing_properties)
576
577         prefix = "demo_"
578         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
579         logger(
580             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
581         )
582         logger(
583             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
584         )
585
586         img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)
587
588         if img.dim() == 5:
589             if img.size(1) == 1:
590                 img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
591             else:
592                 img = torch.cat(
593                     [
594                         torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
595                         for x in img
596                     ],
597                     0,
598                 )
599
600         image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png")
601         torchvision.utils.save_image(
602             img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
603         )
604         logger(f"wrote {image_name}")
605
606
607 ######################################################################
608
609
610 class MNIST(Task):
611     def __init__(
612         self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
613     ):
614         super().__init__()
615
616         self.nb_train_samples = nb_train_samples
617         self.nb_test_samples = nb_test_samples
618         self.batch_size = batch_size
619         self.device = device
620         data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True)
621         self.train_input = data_set.data[:nb_train_samples].view(-1, 28 * 28).long()
622         data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
623         self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
624
625     def batches(self, split="train", nb_to_use=-1, desc=None):
626         assert split in {"train", "test"}
627         input = self.train_input if split == "train" else self.test_input
628         if nb_to_use > 0:
629             input = input[:nb_to_use]
630         if desc is None:
631             desc = f"epoch-{split}"
632         for batch in tqdm.tqdm(
633             input.split(self.batch_size), dynamic_ncols=True, desc=desc
634         ):
635             yield batch
636
637     def vocabulary_size(self):
638         return 256
639
640     def produce_results(
641         self, n_epoch, model, result_dir, logger, deterministic_synthesis
642     ):
643         results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
644         ar_mask = torch.full_like(results, 1)
645         masked_inplace_autoregression(
646             model,
647             self.batch_size,
648             results,
649             ar_mask,
650             deterministic_synthesis,
651             device=self.device,
652         )
653         image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
654         torchvision.utils.save_image(
655             1 - results.reshape(-1, 1, 28, 28) / 255.0,
656             image_name,
657             nrow=16,
658             pad_value=0.8,
659         )
660         logger(f"wrote {image_name}")
661
662
663 ######################################################################
664
665 import maze
666
667
668 class Maze(Task):
669     def map2seq(self, *m):
670         return torch.cat([x.flatten(1) for x in m], 1)
671
672     def seq2map(self, s):
673         s = s.reshape(s.size(0), -1, self.height, self.width)
674         return (s[:, k] for k in range(s.size(1)))
675
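    # With the convention used below, a sequence is the flattened maze
    # (height*width tokens) followed by the flattened path (height*width
    # tokens); seq2map() undoes this concatenation.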
676     def __init__(
677         self,
678         nb_train_samples,
679         nb_test_samples,
680         batch_size,
681         height,
682         width,
683         nb_walls,
684         device=torch.device("cpu"),
685     ):
686         super().__init__()
687
688         self.batch_size = batch_size
689         self.height = height
690         self.width = width
691         self.device = device
692
693         train_mazes, train_paths, _ = maze.create_maze_data(
694             nb_train_samples,
695             height=height,
696             width=width,
697             nb_walls=nb_walls,
698             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
699         )
700         self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
701
702         test_mazes, test_paths, _ = maze.create_maze_data(
703             nb_test_samples,
704             height=height,
705             width=width,
706             nb_walls=nb_walls,
707             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
708         )
709         self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
710
711         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
712
713     def batches(self, split="train", nb_to_use=-1, desc=None):
714         assert split in {"train", "test"}
715         input = self.train_input if split == "train" else self.test_input
716         if nb_to_use > 0:
717             input = input[:nb_to_use]
718         if desc is None:
719             desc = f"epoch-{split}"
720         for batch in tqdm.tqdm(
721             input.split(self.batch_size), dynamic_ncols=True, desc=desc
722         ):
723             yield batch
724
725     def vocabulary_size(self):
726         return self.nb_codes
727
728     def compute_error(
729         self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
730     ):
731         nb_total, nb_correct = 0, 0
732         count = torch.zeros(
733             self.width * self.height,
734             self.width * self.height,
735             device=self.device,
736             dtype=torch.int64,
737         )
738
739         for input in self.batches(split, nb_to_use):
740             result = input.clone()
741             ar_mask = result.new_zeros(result.size())
742             ar_mask[:, self.height * self.width :] = 1
743             result *= 1 - ar_mask
744             masked_inplace_autoregression(
745                 model,
746                 self.batch_size,
747                 result,
748                 ar_mask,
749                 deterministic_synthesis,
750                 progress_bar_desc=None,
751                 device=self.device,
752             )
753             mazes, paths = self.seq2map(result)
754             path_correctness = maze.path_correctness(mazes, paths)
755             nb_correct += path_correctness.long().sum()
756             nb_total += mazes.size(0)
757
758             optimal_path_lengths = (
759                 (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
760             )
761             predicted_path_lengths = (
762                 (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
763             )
764             optimal_path_lengths = optimal_path_lengths[path_correctness]
765             predicted_path_lengths = predicted_path_lengths[path_correctness]
766             count[optimal_path_lengths, predicted_path_lengths] += 1
767
768         if count.max() == 0:
769             count = None
770         else:
771             count = count[
772                 : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
773             ]
774
775         return nb_total, nb_correct, count
776
777     def produce_results(
778         self, n_epoch, model, result_dir, logger, deterministic_synthesis
779     ):
780         train_nb_total, train_nb_correct, count = self.compute_error(
781             model,
782             "train",
783             nb_to_use=1000,
784             deterministic_synthesis=deterministic_synthesis,
785         )
786         logger(
787             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
788         )
789
790         test_nb_total, test_nb_correct, count = self.compute_error(
791             model,
792             "test",
793             nb_to_use=1000,
794             deterministic_synthesis=deterministic_synthesis,
795         )
796         logger(
797             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
798         )
799
800         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
801
802         if count is not None:
803             proportion_optimal = count.diagonal().sum().float() / count.sum()
804             logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
805             with open(
806                 os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
807             ) as f:
808                 for i in range(count.size(0)):
809                     for j in range(count.size(1)):
810                         eol = " " if j < count.size(1) - 1 else "\n"
811                         f.write(f"{count[i,j]}{eol}")
812
813         input = self.test_input[:48]
814         result = input.clone()
815         ar_mask = result.new_zeros(result.size())
816         ar_mask[:, self.height * self.width :] = 1
817         result *= 1 - ar_mask
818         masked_inplace_autoregression(
819             model,
820             self.batch_size,
821             result,
822             ar_mask,
823             deterministic_synthesis,
824             device=self.device,
825         )
826
827         mazes, paths = self.seq2map(input)
828         _, predicted_paths = self.seq2map(result)
829
830         filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png")
831         maze.save_image(
832             filename,
833             mazes=mazes,
834             target_paths=paths,
835             predicted_paths=predicted_paths,
836             path_correct=maze.path_correctness(mazes, predicted_paths),
837             path_optimal=maze.path_optimality(paths, predicted_paths),
838         )
839         logger(f"wrote {filename}")
840
841
842 ######################################################################
843
844
845 import snake
846
847
848 class Snake(Task):
849     def __init__(
850         self,
851         nb_train_samples,
852         nb_test_samples,
853         batch_size,
854         height,
855         width,
856         nb_colors,
857         length,
858         prompt_length,
859         device=torch.device("cpu"),
860     ):
861         super().__init__()
862
863         self.batch_size = batch_size
864         self.height = height
865         self.width = width
866         self.device = device
867         self.prompt_length = prompt_length
868
869         self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
870             nb_train_samples,
871             height,
872             width,
873             nb_colors,
874             length,
875             prompt_length,
876             self.device,
877         )
878         self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
879             nb_test_samples,
880             height,
881             width,
882             nb_colors,
883             length,
884             prompt_length,
885             self.device,
886         )
887
888         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
889
890     def batches(self, split="train", nb_to_use=-1, desc=None):
891         assert split in {"train", "test"}
892         input = self.train_input if split == "train" else self.test_input
893         if nb_to_use > 0:
894             input = input[:nb_to_use]
895         if desc is None:
896             desc = f"epoch-{split}"
897         for batch in tqdm.tqdm(
898             input.split(self.batch_size), dynamic_ncols=True, desc=desc
899         ):
900             yield batch
901
902     def vocabulary_size(self):
903         return self.nb_codes
904
905     def produce_results(
906         self, n_epoch, model, result_dir, logger, deterministic_synthesis
907     ):
908         def compute_nb_correct(input, prior_visits):
909             result = input.clone()
910             i = torch.arange(result.size(1), device=result.device)[None, :]
911             ar_mask = (
912                 torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
913                 .long()
914                 .expand_as(result)
915             )
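            # Keep the first 2*prompt_length tokens and every odd position
            # after them; the even positions past the prompt are the ones
            # masked out and re-generated below.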
916             result *= 1 - ar_mask
917
918             masked_inplace_autoregression(
919                 model,
920                 self.batch_size,
921                 result,
922                 ar_mask,
923                 deterministic_synthesis,
924                 device=self.device,
925             )
926
927             nb_total = ((prior_visits > 0) * ar_mask).sum()
928
929             nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum()
930
931             return nb_total, nb_correct
932
933         test_nb_total, test_nb_correct = compute_nb_correct(
934             self.test_input[:1000], self.test_prior_visits[:1000]
935         )
936
937         logger(
938             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
939         )
940
941         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
942
943
944 ######################################################################
945
946
947 import stack
948
949
950 class Stack(Task):
951     def __init__(
952         self,
953         nb_train_samples,
954         nb_test_samples,
955         batch_size,
956         logger,
957         nb_steps,
958         nb_stacks,
959         nb_digits,
960         fraction_values_for_train=None,
961         device=torch.device("cpu"),
962     ):
963         super().__init__()
964
965         self.batch_size = batch_size
966         self.nb_steps = nb_steps
967         self.nb_stacks = nb_stacks
968         self.nb_digits = nb_digits
969         self.device = device
970
971         if fraction_values_for_train is None:
972             values_for_train = None
973             values_for_test = None
974         else:
975             all = torch.randperm(10**nb_digits)
976             nb_for_train = int(all.size(0) * fraction_values_for_train)
977             values_for_train = all[:nb_for_train]
978             values_for_test = all[nb_for_train:]
979
980         self.train_input, self.train_stack_counts = stack.generate_sequences(
981             nb_train_samples,
982             nb_steps,
983             nb_stacks,
984             nb_digits,
985             values_for_train,
986             self.device,
987         )
988
989         self.test_input, self.test_stack_counts = stack.generate_sequences(
990             nb_test_samples,
991             nb_steps,
992             nb_stacks,
993             nb_digits,
994             values_for_test,
995             self.device,
996         )
997
998         i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
999         counts = self.test_stack_counts.flatten()[i.flatten()]
1000         counts = F.one_hot(counts).sum(0)
1001         logger(f"test_pop_stack_counts {counts}")
1002
1003         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1004
1005     def batches(self, split="train", nb_to_use=-1, desc=None):
1006         assert split in {"train", "test"}
1007         input = self.train_input if split == "train" else self.test_input
1008         if nb_to_use > 0:
1009             input = input[:nb_to_use]
1010         if desc is None:
1011             desc = f"epoch-{split}"
1012         for batch in tqdm.tqdm(
1013             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1014         ):
1015             yield batch
1016
1017     def vocabulary_size(self):
1018         return self.nb_codes
1019
1020     def produce_results(
1021         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1022     ):
1023         def compute_nb_correct(input):
1024             result = input.clone()
1025             stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
1026             ar_mask = (result != input).long()
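            # The mask covers exactly the positions that remove_popped_values()
            # modified, i.e. the values the model has to predict.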
1027             masked_inplace_autoregression(
1028                 model,
1029                 self.batch_size,
1030                 result,
1031                 ar_mask,
1032                 deterministic_synthesis,
1033                 device=self.device,
1034             )
1035
1036             errors = ((result != input).long() * ar_mask).reshape(
1037                 -1, 1 + self.nb_digits
1038             )
1039             ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)
1040
1041             nb_total = ar_mask.max(1).values.sum()
1042             nb_correct = nb_total - errors.max(1).values.sum()
1043
1044             return nb_total, nb_correct
1045
1046         test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
1047
1048         logger(
1049             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1050         )
1051
1052         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
1053
1054         ##############################################################
1055         # Log a few generated sequences
1056         input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
1057         result = input.clone()
1058         stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
1059         ar_mask = (result != input).long()
1060
1061         # for n in range(result.size(0)):
1062         # logger(
1063         # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
1064         # )
1065
1066         masked_inplace_autoregression(
1067             model,
1068             self.batch_size,
1069             result,
1070             ar_mask,
1071             deterministic_synthesis,
1072             device=self.device,
1073         )
1074
1075         for n in range(result.size(0)):
1076             logger(
1077                 f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
1078             )
1079         ##############################################################
1080
1081
1082 ######################################################################
1083
1084 import rpl
1085
1086
1087 class RPL(Task):
1088     def tensorize(self, sequences):
1089         len_max = max([len(x) for x in sequences])
1090         return torch.cat(
1091             [
1092                 torch.tensor(
1093                     [
1094                         [
1095                             self.token2id[str(c)]
1096                             for c in s + ["<nul>"] * (len_max - len(s))
1097                         ]
1098                         for s in sequences
1099                     ]
1100                 )
1101             ],
1102             0,
1103         )
1104
1105     def seq2str(self, seq):
1106         return " ".join([self.id2token[i] for i in seq])
1107
1108     def __init__(
1109         self,
1110         nb_train_samples,
1111         nb_test_samples,
1112         batch_size,
1113         nb_starting_values=3,
1114         max_input=9,
1115         prog_len=6,
1116         nb_runs=5,
1117         no_prog=False,
1118         logger=None,
1119         device=torch.device("cpu"),
1120     ):
1121         super().__init__()
1122
1123         self.batch_size = batch_size
1124         self.device = device
1125         self.no_prog = no_prog
1126
1127         train_sequences = [
1128             rpl.generate(
1129                 nb_starting_values=nb_starting_values,
1130                 nb_result_values_max=4 * nb_starting_values,
1131                 max_input=max_input,
1132                 prog_len=prog_len,
1133                 nb_runs=nb_runs,
1134             )
1135             for _ in tqdm.tqdm(range(nb_train_samples), desc="train-data")
1136         ]
1137
1138         test_sequences = [
1139             rpl.generate(
1140                 nb_starting_values=nb_starting_values,
1141                 nb_result_values_max=4 * nb_starting_values,
1142                 max_input=max_input,
1143                 prog_len=prog_len,
1144                 nb_runs=nb_runs,
1145             )
1146             for _ in tqdm.tqdm(range(nb_test_samples), desc="test-data")
1147         ]
1148
1149         symbols = list(
1150             set(["<nul>"] + [x for l in train_sequences + test_sequences for x in l])
1151         )
1152         val_max = max([x if type(x) is int else 0 for x in symbols])
1153         symbols = list(filter(lambda x: type(x) is str, symbols))
1154         symbols.sort()
1155         symbols += [str(n) for n in range(val_max + 1)]
1156         self.token2id = dict([(c, n) for n, c in enumerate(symbols)])
1157         self.id2token = dict([(n, c) for c, n in self.token2id.items()])
1158
1159         self.t_nul = self.token2id["<nul>"]
1160         self.t_input = self.token2id["<in>"]
1161         self.t_output = self.token2id["<out>"]
1162         self.t_prog = self.token2id["<prg>"]
1163         self.t_end = self.token2id["<end>"]
1164
1165         self.train_input = self.tensorize(train_sequences)
1166         self.test_input = self.tensorize(test_sequences)
1167
1168         if no_prog:
1169             # Excise the program from every train and test example
1170             k = torch.arange(self.train_input.size(1), device=self.train_input.device)[
1171                 None, :
1172             ]
1173             p = (
1174                 ((self.train_input == self.t_prog).long() * k)
1175                 .max(1, keepdim=True)
1176                 .values
1177             )
1178             self.train_input = (
1179                 self.train_input * (k <= p).long()
1180                 + self.t_end * (k == p + 1).long()
1181                 + self.t_nul * (k > p + 1).long()
1182             )
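            # i.e. keep everything up to and including the <prg> token, put
            # <end> right after it, and pad the rest with <nul>.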
1183             k = torch.arange(self.test_input.size(1), device=self.test_input.device)[
1184                 None, :
1185             ]
1186             p = (
1187                 ((self.test_input == self.t_prog).long() * k)
1188                 .max(1, keepdim=True)
1189                 .values
1190             )
1191             self.test_input = (
1192                 self.test_input * (k <= p).long()
1193                 + self.t_end * (k == p + 1).long()
1194                 + self.t_nul * (k > p + 1).long()
1195             )
1196
1197         if logger is not None:
1198             logger(f"value_max {val_max}")
1199             for x in self.train_input[:25]:
1200                 end = (x != self.t_nul).nonzero().max().item() + 1
1201                 seq = [self.id2token[i.item()] for i in x[:end]]
1202                 s = " ".join(seq)
1203                 logger(f"example_seq {s}")
1204
1205         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1206
1207     def batches(self, split="train", nb_to_use=-1, desc=None):
1208         assert split in {"train", "test"}
1209         input = self.train_input if split == "train" else self.test_input
1210         if nb_to_use > 0:
1211             input = input[:nb_to_use]
1212         if desc is None:
1213             desc = f"epoch-{split}"
1214         for batch in tqdm.tqdm(
1215             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1216         ):
1217             last = (batch != self.t_nul).max(0).values.nonzero().max() + 3
1218             batch = batch[:, :last].to(self.device)
1219             yield batch
1220
1221     def vocabulary_size(self):
1222         return self.nb_codes
1223
1224     def produce_results(
1225         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1226     ):
1227         # --------------------------------------------------------------------
1228         def compute_nb_errors_prog(input, nb_to_log=0):
1229             result = input.clone()
1230             s = (result == self.t_prog).long()
1231             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
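            # ar_mask is 1 strictly after the first <prg> token: the program
            # body is erased below and re-generated, while the <prg> marker
            # itself stays part of the prompt.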
1232             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1233
1234             masked_inplace_autoregression(
1235                 model,
1236                 self.batch_size,
1237                 result,
1238                 ar_mask,
1239                 deterministic_synthesis,
1240                 device=self.device,
1241             )
1242
1243             sum_nb_total, sum_nb_errors = 0, 0
1244             for one_input, one_result in zip(input, result):
1245                 seq = [self.id2token[i.item()] for i in one_result]
1246                 nb_total, nb_errors, prog, stacks = rpl.compute_nb_errors(seq)
1247                 sum_nb_total += 1
1248                 sum_nb_errors += 0 if nb_errors == 0 else 1
1249                 if nb_to_log > 0:
1250                     gt_seq = [self.id2token[i.item()] for i in one_input]
1251                     _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq)
1252                     gt_prog = " ".join([str(x) for x in gt_prog])
1253                     prog = " ".join([str(x) for x in prog])
1254                     comment = "*" if nb_errors == 0 else "-"
1255                     logger(f"{comment} PROG [{gt_prog}] PREDICTED [{prog}]")
1256                     for start_stack, target_stack, result_stack, correct in stacks:
1257                         comment = "*" if correct else "-"
1258                         start_stack = " ".join([str(x) for x in start_stack])
1259                         target_stack = " ".join([str(x) for x in target_stack])
1260                         result_stack = " ".join([str(x) for x in result_stack])
1261                         logger(
1262                             f"  {comment} [{start_stack}] -> [{target_stack}] PREDICTED [{result_stack}]"
1263                         )
1264                     nb_to_log -= 1
1265
1266             return sum_nb_total, sum_nb_errors
1267
1268         # --------------------------------------------------------------------
1269         def compute_nb_errors_output(input, nb_to_log=0):
1270             result = input.clone()
1271             k = torch.arange(result.size(1), device=result.device)[None, :]
1272             last_output_idx = (
1273                 ((result == self.t_output) * k).max(dim=1, keepdim=True).values
1274             )
1275             first_prog_idx = (
1276                 ((result == self.t_prog) * k).max(dim=1, keepdim=True).values
1277             )
1278             ar_mask = (k > last_output_idx).long() * (k < first_prog_idx).long()
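            # The mask covers the tokens strictly between the last <out>
            # marker and the <prg> marker, i.e. the values of the final run,
            # which the model has to predict.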
1279             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1280
1281             masked_inplace_autoregression(
1282                 model,
1283                 self.batch_size,
1284                 result,
1285                 ar_mask,
1286                 deterministic_synthesis,
1287                 device=self.device,
1288             )
1289
1290             sum_nb_total, sum_nb_errors = 0, 0
1291             for one_input, one_result, i, j in zip(
1292                 input, result, last_output_idx, first_prog_idx
1293             ):
1294                 seq = [self.id2token[i.item()] for i in one_result]
1295                 sum_nb_total += 1
1296                 correct = (one_input - one_result).abs().max() == 0
1297                 sum_nb_errors += 0 if correct else 1
1298                 if nb_to_log > 0:
1299                     result_stack = [
1300                         self.id2token[i.item()] for i in one_result[i : j + 1]
1301                     ]
1302                     target_stack = [
1303                         self.id2token[i.item()] for i in one_input[i : j + 1]
1304                     ]
1305                     comment = "*" if correct else "-"
1306                     result_stack = " ".join([str(x) for x in result_stack])
1307                     target_stack = " ".join([str(x) for x in target_stack])
1308                     logger(
1309                         f"output_test {comment} [{target_stack}] PREDICTED [{result_stack}]"
1310                     )
1311                     nb_to_log -= 1
1312
1313             return sum_nb_total, sum_nb_errors
1314
1315         # --------------------------------------------------------------------
1316
1317         if not self.no_prog:
1318             test_nb_total, test_nb_errors = compute_nb_errors_prog(
1319                 self.test_input[:1000].to(self.device), nb_to_log=10
1320             )
1321
1322             logger(
1323                 f"accuracy_prog_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1324             )
1325
1326             logger(f"main_test_accuracy {n_epoch} {1-test_nb_errors/test_nb_total}")
1327
1328         test_nb_total, test_nb_errors = compute_nb_errors_output(
1329             self.test_input[:1000].to(self.device), nb_to_log=10
1330         )
1331
1332         logger(
1333             f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1334         )
1335
1336         if save_attention_image is None:
1337             logger("no save_attention_image (is pycairo installed?)")
1338         else:
1339             ns = torch.randint(self.test_input.size(0), (1,)).item()
1340             input = self.test_input[ns : ns + 1].clone()
1341             last = (input != self.t_nul).max(0).values.nonzero().max() + 3
1342             input = input[:, :last].to(self.device)
1343
1344             with torch.autograd.no_grad():
1345                 t = model.training
1346                 model.eval()
1347                 model.record_attention(True)
1348                 model(BracketedSequence(input))
1349                 model.train(t)
1350                 ram = model.retrieve_attention()
1351                 model.record_attention(False)
1352
1353             tokens_output = [self.id2token[i.item()] for i in input[0]]
1354             tokens_input = ["n/a"] + tokens_output[:-1]
1355             for n_head in range(ram[0].size(1)):
1356                 filename = os.path.join(
1357                     result_dir, f"rpl_attention_{n_epoch}_h{n_head}.pdf"
1358                 )
1359                 attention_matrices = [m[0, n_head] for m in ram]
1360                 save_attention_image(
1361                     filename,
1362                     tokens_input,
1363                     tokens_output,
1364                     attention_matrices,
1365                     k_top=10,
1366                     # min_total_attention=0.9,
1367                     token_gap=12,
1368                     layer_gap=50,
1369                 )
1370                 logger(f"wrote {filename}")
1371
1372
1373 ######################################################################
1374
1375
1376 import expr
1377
1378
1379 class Expr(Task):
1380     def tensorize(self, sequences):
1381         len_max = max([len(x) for x in sequences])
1382         return torch.cat(
1383             [
1384                 torch.tensor(
1385                     [
1386                         [self.char2id[c] for c in s + "#" * (len_max - len(s))]
1387                         for s in sequences
1388                     ]
1389                 )
1390             ],
1391             0,
1392         ).to(self.device)
1393
1394     def __init__(
1395         self,
1396         nb_train_samples,
1397         nb_test_samples,
1398         nb_variables,
1399         sequence_length,
1400         operand_max,
1401         result_max,
1402         batch_size,
1403         device=torch.device("cpu"),
1404     ):
1405         super().__init__()
1406
1407         self.batch_size = batch_size
1408         self.device = device
1409
1410         train_sequences = expr.generate_sequences(
1411             nb_train_samples,
1412             nb_variables=nb_variables,
1413             length=sequence_length,
1414             operand_max=operand_max,
1415             result_max=result_max,
1416         )
1417
1418         test_sequences = expr.generate_sequences(
1419             nb_test_samples,
1420             nb_variables=nb_variables,
1421             length=sequence_length,
1422             operand_max=operand_max,
1423             result_max=result_max,
1424         )
1425
1426         symbols = list(set("#" + "".join(train_sequences + test_sequences)))
1427         symbols.sort()
1428
1429         self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
1430         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
1431
1432         self.filler, self.space = self.char2id["#"], self.char2id[" "]
1433
1434         self.train_input = self.tensorize(train_sequences)
1435         self.test_input = self.tensorize(test_sequences)
1436
1437         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1438
1439     def batches(self, split="train", nb_to_use=-1, desc=None):
1440         assert split in {"train", "test"}
1441         input = self.train_input if split == "train" else self.test_input
1442         if nb_to_use > 0:
1443             input = input[:nb_to_use]
1444         if desc is None:
1445             desc = f"epoch-{split}"
1446         for batch in tqdm.tqdm(
1447             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1448         ):
1449             last = (batch != self.filler).max(0).values.nonzero().max() + 3
1450             batch = batch[:, :last]
1451             yield batch
1452
1453     def vocabulary_size(self):
1454         return self.nb_codes
1455
1456     def seq2str(self, s):
1457         return "".join([self.id2char[k.item()] for k in s])
1458
1459     def produce_results(
1460         self,
1461         n_epoch,
1462         model,
1463         result_dir,
1464         logger,
1465         deterministic_synthesis,
1466         input_file=None,
1467     ):
1468         def compute_nb_correct(input):
1469             result = input.clone()
1470             s = (result == self.space).long()
1471             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1472             result = (1 - ar_mask) * result + ar_mask * self.filler
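            # Everything strictly after the first space character is blanked
            # to the filler token and re-generated by the model; what comes
            # before it is kept as the prompt.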
1473             masked_inplace_autoregression(
1474                 model,
1475                 self.batch_size,
1476                 result,
1477                 ar_mask,
1478                 deterministic_synthesis,
1479                 device=self.device,
1480             )
1481
1482             nb_total = input.size(0)
1483             nb_correct = (input == result).long().min(1).values.sum()
1484
1485             #######################################################################
1486             # Compute predicted vs. true variable values
1487
1488             nb_delta = torch.zeros(5, dtype=torch.int64)
1489             nb_missed = 0
1490
1491             values_input = expr.extract_results([self.seq2str(s) for s in input])
1492             values_result = expr.extract_results([self.seq2str(s) for s in result])
1493
1494             filename = os.path.join(result_dir, f"expr_result_{n_epoch:04d}.txt")
1495
1496             with open(filename, "w") as f:
1497                 for i, r in zip(values_input, values_result):
1498                     for n, vi in i.items():
1499                         vr = r.get(n)
1500                         f.write(f"{vi} {-1 if vr is None else vr}\n")
1501
1502                         if vr is None or vr < 0:
1503                             nb_missed += 1
1504                         else:
1505                             d = abs(vr - vi)
1506                             if d >= nb_delta.size(0):
1507                                 nb_missed += 1
1508                             else:
1509                                 nb_delta[d] += 1
1510
1511             ######################################################################
1512
1513             return nb_total, nb_correct, nb_delta, nb_missed
1514
1515         (
1516             test_nb_total,
1517             test_nb_correct,
1518             test_nb_delta,
1519             test_nb_missed,
1520         ) = compute_nb_correct(self.test_input[:10000])
1521
1522         logger(
1523             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1524         )
1525
1526         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
1527
1528         nb_total = test_nb_delta.sum() + test_nb_missed
1529         for d in range(test_nb_delta.size(0)):
1530             logger(
1531                 f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%"
1532             )
1533         logger(
1534             f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%"
1535         )
1536
1537         ##############################################################
1538         # Log a few generated sequences
1539         if input_file is None:
1540             input = self.test_input[:10]
1541         else:
1542             with open(input_file, "r") as f:
1543                 sequences = [e.strip() for e in f.readlines()]
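                # append a space and a run of filler tokens to each prompt so
                # that there is room for the generated values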
1544                 sequences = [s + " " + "#" * 50 for s in sequences]
1545                 input = self.tensorize(sequences)
1546
1547         result = input.clone()
1548         s = (result == self.space).long()
1549         ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1550         result = (1 - ar_mask) * result + ar_mask * self.filler
1551
1552         for n in range(result.size(0)):
1553             logger(f"test_before {self.seq2str(result[n])}")
1554
1555         masked_inplace_autoregression(
1556             model,
1557             self.batch_size,
1558             result,
1559             ar_mask,
1560             deterministic_synthesis,
1561             device=self.device,
1562         )
1563
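        # Reference line for logging: spaces on the prompt part, ground-truth
        # tokens on the positions that were generated.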
1564         correct = (1 - ar_mask) * self.space + ar_mask * input
1565         for n in range(result.size(0)):
1566             comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
1567             logger(f"test_after  {self.seq2str(result[n])} {comment}")
1568             logger(f"truth       {self.seq2str(correct[n])}")
1569         ##############################################################
1570
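# A minimal, self-contained sketch (hypothetical helper, not part of the
# original code) of the masking pattern used in produce_results of the
# expression task above: every position strictly after the first space token
# of a row is marked for regeneration.
def _demo_expr_ar_mask():
    space = 0  # hypothetical id of the space token
    x = torch.tensor([[3, 4, 0, 5, 6], [7, 0, 8, 9, 1]])
    s = (x == space).long()
    ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
    # ar_mask == [[0, 0, 0, 1, 1], [0, 0, 1, 1, 1]]
    return ar_mask
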
1571
1572 ######################################################################
1573
1574 import grid
1575
1576
1577 class Grid(Task):
1578     # Make a tensor from a list of strings
1579     def str2tensor(self, descr):
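    # e.g. (hypothetical tokens) ["a b true", "a b c false"] is padded with
    # "#" to the length of the longest description and mapped to a 2 x 4
    # tensor of token ids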
1580         token_descr = [s.strip().split(" ") for s in descr]
1581         l = max([len(s) for s in token_descr])
1582         token_descr = [s + ["#"] * (l - len(s)) for s in token_descr]
1583         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
1584         return torch.tensor(id_descr, device=self.device)
1585
1586     # Make a list of strings from a tensor
1587     def tensor2str(self, x):
1588         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
1589
1590     # Trim z to remove as many null tokens as possible from the left and
1591     # the right. If z is a tuple, all its elements are trimmed according
1592     # to the trimming computed for the first one.
1593     def trim(self, z, token="#"):
1594         n = self.token2id[token]
1595         if type(z) == tuple:
1596             x = z[0]
1597             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
1598             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
1599             return tuple([t[:, a:b] for t in z])
1600         else:
1601             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
1602             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
1603             return z[:, a:b]
1604
1605     ######################
1606
1607     def __init__(
1608         self,
1609         nb_train_samples,
1610         nb_test_samples,
1611         batch_size,
1612         size,
1613         fraction_play=0.0,
1614         logger=None,
1615         device=torch.device("cpu"),
1616     ):
1617         super().__init__()
1618
1619         self.device = device
1620         self.batch_size = batch_size
1621         self.grid_factory = grid.GridFactory(size=size)
1622         self.fraction_play = fraction_play
1623
1624         if logger is not None:
1625             logger(
1626                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
1627             )
1628
1629         self.train_descr = self.grid_factory.generate_samples(
1630             nb=nb_train_samples,
1631             fraction_play=fraction_play,
1632             progress_bar=lambda r: tqdm.tqdm(r),
1633         )
1634
1635         self.test_descr = self.grid_factory.generate_samples(
1636             nb=nb_test_samples, fraction_play=0.0, progress_bar=lambda r: tqdm.tqdm(r)
1637         )
1638
1639         if fraction_play > 0:
1640             self.play_descr = self.grid_factory.generate_samples(
1641                 nb=25, fraction_play=1.0, progress_bar=lambda r: tqdm.tqdm(r)
1642             )
1643         else:
1644             self.play_descr = []
1645
1646         # Build the tokenizer
1647         tokens = set()
1648         for d in [self.train_descr, self.test_descr, self.play_descr]:
1649             for s in d:
1650                 for t in s.strip().split(" "):
1651                     tokens.add(t)
1652         # turn the set into a sorted list so that identical descriptions
1653         # always produce the same tensors
1654         tokens = list(tokens)
1655         tokens.sort()
1656         tokens = ["#"] + tokens
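        # prepend "#" so that the null / padding token always gets id 0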
1657         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
1658         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
1659         self.t_nul = self.token2id["#"]
1660         self.t_true = self.token2id["true"]
1661         self.t_false = self.token2id["false"]
1662         self.t_pipe = self.token2id["|"]
1663
1664         # Tokenize the train and test sets
1665         self.train_input = self.str2tensor(self.train_descr)
1666         self.test_input = self.str2tensor(self.test_descr)
1667         self.play_input = (
1668             None if len(self.play_descr) == 0 else self.str2tensor(self.play_descr)
1669         )
1670
1671     def batches(self, split="train"):
1672         assert split in {"train", "test"}
1673         input = self.train_input if split == "train" else self.test_input
1674         for batch in tqdm.tqdm(
1675             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
1676         ):
1677             yield self.trim(batch)
1678
1679     def vocabulary_size(self):
1680         return len(self.token2id)
1681
1682     def produce_results(
1683         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1684     ):
1685         correct = self.test_input[:1000]
1686         result = correct.clone()
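        # Only the positions holding a "true" / "false" token are regenerated.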
1687         ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long()
1688         result *= 1 - ar_mask  # paranoia: zero out the positions to be generated
1689
1690         logger(f"----------------------------------------------------------")
1691
1692         for e in self.tensor2str(result[:10]):
1693             logger(f"test_before {e}")
1694
1695         masked_inplace_autoregression(
1696             model,
1697             self.batch_size,
1698             result,
1699             ar_mask,
1700             deterministic_synthesis,
1701             device=self.device,
1702         )
1703
1704         logger(f"----------------------------------------------------------")
1705
1706         for e in self.tensor2str(result[:10]):
1707             logger(f"test_after  {e}")
1708
1709         logger(f"----------------------------------------------------------")
1710
1711         nb_total = ar_mask.sum().item()
1712         nb_correct = ((correct == result).long() * ar_mask).sum().item()
1713
1714         logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
1715         logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
1716
1717         if self.play_input is not None:
1718             result = self.play_input.clone()
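            # Regenerate everything from the first "|" token onward.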
1719             ar_mask = (result == self.t_pipe).long().cumsum(dim=1).clamp(max=1)
1720             result *= 1 - ar_mask  # paranoia: zero out the positions to be generated
1721
1722             logger(f"----------------------------------------------------------")
1723
1724             for e in self.tensor2str(result[:10]):
1725                 logger(f"play_before {e}")
1726
1727             masked_inplace_autoregression(
1728                 model,
1729                 self.batch_size,
1730                 result,
1731                 ar_mask,
1732                 deterministic_synthesis,
1733                 device=self.device,
1734             )
1735
1736             logger(f"----------------------------------------------------------")
1737
1738             for e in self.tensor2str(result[:10]):
1739                 logger(f"play_after  {e}")
1740
1741             logger(f"----------------------------------------------------------")
1742
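# A minimal, self-contained sketch (hypothetical helper, not part of the
# original code) of the column-trimming logic shared by TaskFromFile.trim and
# Grid.trim: leading and trailing columns made only of the null token n are
# removed.
def _demo_trim_columns(x, n):
    i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
    a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
    return x[:, a:b]


# e.g. _demo_trim_columns(torch.tensor([[0, 0, 1, 2, 0], [0, 3, 4, 0, 0]]), 0)
# returns tensor([[0, 1, 2], [3, 4, 0]])
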
1743
1744 ######################################################################
1745
1746 import qmlp
1747
1748
1749 class QMLP(Task):
1750     ######################
1751
1752     def __init__(
1753         self,
1754         nb_train_samples,
1755         nb_test_samples,
1756         batch_size,
1757         result_dir,
1758         logger=None,
1759         device=torch.device("cpu"),
1760     ):
1761         super().__init__()
1762
1763         self.device = device
1764         self.batch_size = batch_size
1765         self.nb_samples_per_mlp = 256
1766
1767         if logger is not None:
1768             logger(
1769                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
1770             )
1771
1772         seq, q_test_set, test_error = qmlp.generate_sequence_and_test_set(
1773             nb_mlps=nb_train_samples + nb_test_samples,
1774             nb_samples=self.nb_samples_per_mlp,
1775             device=self.device,
1776             batch_size=64,
1777             nb_epochs=250,
1778             nb_mlps_per_batch=1024,
1779         )
1780
1781         self.train_input = seq[:nb_train_samples]
1782         self.train_q_test_set = q_test_set[:nb_train_samples]
1783         self.train_ref_test_errors = test_error[:nb_train_samples]
1784         self.test_input = seq[nb_train_samples:]
1785         self.test_q_test_set = q_test_set[nb_train_samples:]
1786         self.test_ref_test_errors = test_error[nb_train_samples:]
1787
1788         filename = os.path.join(result_dir, "train_errors_ref.dat")
1789         with open(filename, "w") as f:
1790             for e in self.train_ref_test_errors:
1791                 f.write(f"{e}\n")
1792
1793         filename = os.path.join(result_dir, "test_errors_ref.dat")
1794         with open(filename, "w") as f:
1795             for e in self.test_ref_test_errors:
1796                 f.write(f"{e}\n")
1797
1798         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1799
1800     def batches(self, split="train"):
1801         assert split in {"train", "test"}
1802         input = self.train_input if split == "train" else self.test_input
1803         for batch in tqdm.tqdm(
1804             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
1805         ):
1806             yield batch
1807
1808     def vocabulary_size(self):
1809         return self.nb_codes
1810
1811     def produce_results(
1812         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1813     ):
1814         correct = self.test_input[:1000]
1815         result = correct.clone()
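        # Each sequence starts with the quantized training set
        # (3 * nb_samples_per_mlp tokens, presumably 3 tokens per sample)
        # followed by a separator; only the tail holding the quantized MLP
        # parameters is regenerated.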
1816         ar_mask = (
1817             torch.arange(result.size(1), device=result.device)
1818             > self.nb_samples_per_mlp * 3 + 1
1819         ).long()[None, :]
1820         ar_mask = ar_mask.expand_as(result)
1821         result *= 1 - ar_mask  # paranoia: zero out the positions to be generated
1822
1823         masked_inplace_autoregression(
1824             model,
1825             self.batch_size,
1826             result,
1827             ar_mask,
1828             deterministic_synthesis,
1829             device=self.device,
1830         )
1831
1832         q_train_set = result[:, : self.nb_samples_per_mlp * 3]
1833         q_params = result[:, self.nb_samples_per_mlp * 3 + 1 :]
1834         error_test = qmlp.evaluate_q_params(q_params, self.test_q_test_set)
1835
1836         filename = os.path.join(result_dir, f"test_errors_{n_epoch:04d}.dat")
1837         with open(filename, "w") as f:
1838             for e in error_test:
1839                 f.write(f"{e}\n")
1840
1841
1842 ######################################################################