[picoclvr.git] / tasks.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, os, tqdm
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 from mygpt import BracketedSequence
16
17 # from graph import save_attention_image
18 save_attention_image = None
19
20 ######################################################################
21
22
23 def masked_inplace_autoregression(
24     model,
25     batch_size,
26     input,
27     ar_mask,
28     deterministic_synthesis,
29     forbidden_tokens=None,
30     progress_bar_desc="autoregression",
31     device=torch.device("cpu"),
32 ):
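    # Fill in, in place and batch by batch, the positions of `input` where
    # `ar_mask` is 1, by letting the model generate them autoregressively.
    # The model is temporarily switched to eval mode and gradients are
    # disabled; the original training mode is restored afterwards.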
33     assert input.size() == ar_mask.size()
34
35     batches = zip(input.split(batch_size), ar_mask.split(batch_size))
36
37     if progress_bar_desc is not None:
38         batches = tqdm.tqdm(
39             batches,
40             dynamic_ncols=True,
41             desc=progress_bar_desc,
42             total=(input.size(0) + batch_size - 1) // batch_size,
43         )
44
45     with torch.autograd.no_grad():
46         t = model.training
47         model.eval()
48
49         for input, ar_mask in batches:
50             model.masked_inplace_autoregression(
51                 input, ar_mask, forbidden_tokens, deterministic_synthesis
52             )
53
54         model.train(t)
55
56
57 ######################################################################
58
59
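# Base class for the tasks below: a task provides batches of token sequences
# for training and test, exposes its vocabulary size, and implements
# produce_results, which evaluates a model and logs metrics and artifacts.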
60 class Task:
61     def batches(self, split="train"):
62         pass
63
64     def vocabulary_size(self):
65         pass
66
67     def produce_results(
68         self, n_epoch, model, result_dir, logger, deterministic_synthesis
69     ):
70         pass
71
72
73 class TaskFromFile(Task):
74     def tensorize(self, pairs):
75         len_max = max([len(x[0]) for x in pairs])
76
77         input = torch.cat(
78             [
79                 torch.tensor(
80                     [
81                         [self.char2id[c] for c in s[0] + "#" * (len_max - len(s[0]))]
82                         for s in pairs
83                     ]
84                 )
85             ],
86             0,
87         ).to("cpu")
88
89         pred_mask = torch.cat(
90             [
91                 torch.tensor(
92                     [
93                         [int(c) for c in s[1] + "0" * (len_max - len(s[1]))]
94                         for s in pairs
95                     ]
96                 )
97             ],
98             0,
99         ).to("cpu")
100
101         return input, pred_mask
102
103     # Trim z to remove as many columns made only of padding tokens as
104     # possible from its left and right. If z is a tuple, all its elements
105     # are trimmed according to the trimming computed for its first element.
106     def trim(self, z, token="#"):
107         n = self.char2id[token]
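        # A border column can be dropped when every sequence of the batch
        # holds the padding token there; the cumulative sum below locates the
        # first and last columns that must be kept.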
108         if type(z) == tuple:
109             x = z[0]
110             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
111             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
112             return tuple([t[:, a:b] for t in z])
113         else:
114             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
115             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
116             return z[:, a:b]
117
118     def __init__(
119         self,
120         train_filename,
121         test_filename,
122         nb_train_samples,
123         nb_test_samples,
124         batch_size,
125         device=torch.device("cpu"),
126     ):
127         self.batch_size = batch_size
128         self.device = device
129
130         def read_file(filename, nb=-1):
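            # The file contains pairs of lines: a token sequence followed by a
            # mask of the same length made of characters 0 (context), 1 or 2
            # (positions to generate, the 2s being the ones scored at test
            # time).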
131             pairs = []
132             with open(filename, "r") as f:
133                 while True:
134                     sequence = f.readline().strip()
135                     if not sequence:
136                         break
137                     pred_mask = f.readline().strip()
138                     assert len(sequence) == len(pred_mask)
139                     assert set(pred_mask).issubset({"0", "1", "2"}), f"{set(pred_mask)}"
140                     pairs.append((sequence, pred_mask))
141                     if len(pairs) == nb:
142                         break
143
144             if nb > 0:
145                 pairs = pairs[:nb]
146                 assert len(pairs) == nb
147
148             return pairs
149
150         train_pairs = read_file(train_filename, nb_train_samples)
151         test_pairs = read_file(test_filename, nb_test_samples)
152
153         symbols = ["#"] + sorted(  # sorted so that token ids are reproducible
154             set("".join([x[0] for x in train_pairs + test_pairs])) - {"#"}
155         )
156         self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
157         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
158
159         self.train_input, self.train_pred_masks = self.tensorize(train_pairs)
160         self.test_input, self.test_pred_masks = self.tensorize(test_pairs)
161
162     def batches(self, split="train", nb_to_use=-1, desc=None):
163         assert split in {"train", "test"}
164         input = self.train_input if split == "train" else self.test_input
165         if nb_to_use > 0:
166             input = input[:nb_to_use]
167         if desc is None:
168             desc = f"epoch-{split}"
169         for batch in tqdm.tqdm(
170             input.split(self.batch_size), dynamic_ncols=True, desc=desc
171         ):
172             yield self.trim(batch).to(self.device)
173
174     def vocabulary_size(self):
175         return len(self.char2id)
176
177     def tensor2str(self, t):
178         return ["".join([self.id2char[x.item()] for x in s]) for s in t]
179
180     def produce_results(
181         self, n_epoch, model, result_dir, logger, deterministic_synthesis
182     ):
183         correct = self.trim(self.test_input[:1000]).to(self.device)
184         result = correct.clone()
185         pred_mask = self.test_pred_masks[:1000, : result.size(1)].to(self.device)
186         ar_mask = (pred_mask > 0).long()
187         result *= 1 - ar_mask  # paranoia: erase everything that must be predicted
188
189         logger(f"----------------------------------------------------------")
190
191         for e in self.tensor2str(result[:50]):
192             logger(f"test_before {e}")
193
194         masked_inplace_autoregression(
195             model,
196             self.batch_size,
197             result,
198             ar_mask,
199             deterministic_synthesis,
200             device=self.device,
201         )
202
203         logger(f"----------------------------------------------------------")
204
205         for e, c in zip(self.tensor2str(result[:50]), self.tensor2str(correct[:50])):
206             logger(f"test_after  {e}")
207             logger(f"correct     {c}")
208
209         logger(f"----------------------------------------------------------")
210
211         err_mask = (pred_mask == 2).long()
212         nb_total = err_mask.sum().item()
213         nb_correct = ((correct == result).long() * err_mask).sum().item()
214
215         logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
216         logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
217
218
219 ####################
220
221 import problems
222
223
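# Wraps a Problem from problems.py: the problem generates the train / test
# sequences and their autoregression masks, renders sequences as strings, and
# scores the model's completions.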
224 class SandBox(Task):
225     def __init__(
226         self,
227         problem,
228         nb_train_samples,
229         nb_test_samples,
230         batch_size,
231         logger=None,
232         device=torch.device("cpu"),
233         max_nb_codes=1024,
234     ):
235         super().__init__()
236
237         self.batch_size = batch_size
238         self.device = device
239         self.problem = problem
240
241         self.train_input, self.train_ar_mask = self.problem.generate_sequences(
242             nb_train_samples
243         )
244         self.test_input, self.test_ar_mask = self.problem.generate_sequences(
245             nb_test_samples
246         )
247
248         self.train_input, self.train_ar_mask = self.train_input.to(
249             device
250         ), self.train_ar_mask.to(device)
251         self.test_input, self.test_ar_mask = self.test_input.to(
252             device
253         ), self.test_ar_mask.to(device)
254
255         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
256
257         # A bit of paranoia never hurts
258         assert self.nb_codes <= max_nb_codes
259         assert self.train_input.min() >= 0
260         assert self.test_input.min() >= 0
261         assert tuple(x.item() for x in self.train_ar_mask.unique()) in {
262             (0,),
263             (1,),
264             (0, 1),
265         }
266         assert tuple(x.item() for x in self.test_ar_mask.unique()) in {
267             (0,),
268             (1,),
269             (0, 1),
270         }
271
272         if logger is not None:
273             for s, a in zip(self.train_input[:100], self.train_ar_mask[:100]):
274                 logger(f"train_sequences {self.problem.seq2str(s)}")
275                 a = "".join(["01"[x.item()] for x in a])
276                 logger(f"                {a}")
277
278     def batches(self, split="train", nb_to_use=-1, desc=None):
279         assert split in {"train", "test"}
280         input = self.train_input if split == "train" else self.test_input
281         if nb_to_use > 0:
282             input = input[:nb_to_use]
283         if desc is None:
284             desc = f"epoch-{split}"
285         for batch in tqdm.tqdm(
286             input.split(self.batch_size), dynamic_ncols=True, desc=desc
287         ):
288             yield batch
289
290     def vocabulary_size(self):
291         return self.nb_codes
292
293     def produce_results(
294         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
295     ):
296         def compute_accuracy(input, ar_mask, logger=None):
297             input, ar_mask = input[:nmax], ar_mask[:nmax]
298             result = input.clone() * (1 - ar_mask)
299
300             masked_inplace_autoregression(
301                 model,
302                 self.batch_size,
303                 result,
304                 ar_mask,
305                 deterministic_synthesis,
306                 progress_bar_desc=None,
307                 device=self.device,
308             )
309
310             log_ground_truth = ar_mask.min() == 0
311
312             if logger is not None:
313                 for sp, st in zip(result[:10], input[:10]):
314                     logger(
315                         f"test_sequences {n_epoch} prediction   {self.problem.seq2str(sp)}"
316                     )
317                     if log_ground_truth:
318                         logger(
319                             f"               {n_epoch} ground truth {self.problem.seq2str(st)}"
320                         )
321
322             nb_total, nb_correct = self.problem.compute_nb_correct(
323                 input, ar_mask, result
324             )
325
326             # nb_total = ar_mask.sum().item()
327             # nb_correct = ((result == input).long() * ar_mask).sum().item()
328
329             return nb_total, nb_correct
330
331         train_nb_total, train_nb_correct = compute_accuracy(
332             self.train_input, self.train_ar_mask
333         )
334
335         logger(
336             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
337         )
338
339         test_nb_total, test_nb_correct = compute_accuracy(
340             self.test_input, self.test_ar_mask, logger
341         )
342
343         logger(
344             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
345         )
346
347         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
348
349         if save_attention_image is not None:
350             for k in range(10):
351                 ns = torch.randint(self.test_input.size(0), (1,)).item()
352                 input = self.test_input[ns : ns + 1].clone()
353
354                 with torch.autograd.no_grad():
355                     t = model.training
356                     model.eval()
357                     # model.record_attention(True)
358                     model(BracketedSequence(input))
359                     model.train(t)
360                     # ram = model.retrieve_attention()
361                     # model.record_attention(False)
362
363                 # tokens_output = [c for c in self.problem.seq2str(input[0])]
364                 # tokens_input = ["n/a"] + tokens_output[:-1]
365                 # for n_head in range(ram[0].size(1)):
366                 # filename = os.path.join(
367                 # result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
368                 # )
369                 # attention_matrices = [m[0, n_head] for m in ram]
370                 # save_attention_image(
371                 # filename,
372                 # tokens_input,
373                 # tokens_output,
374                 # attention_matrices,
375                 # k_top=10,
376                 ##min_total_attention=0.9,
377                 # token_gap=12,
378                 # layer_gap=50,
379                 # )
380                 # logger(f"wrote {filename}")
381
382
383 ######################################################################
384
385 import picoclvr
386
387
388 class PicoCLVR(Task):
389     # Make a tensor from a list of strings
390     def tensorize(self, descr):
391         token_descr = [s.strip().split(" ") for s in descr]
392         len_max = max([len(s) for s in token_descr])
393         token_descr = [s + ["<nul>"] * (len_max - len(s)) for s in token_descr]
394         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
395         return torch.tensor(id_descr, device=self.device)
396
397     # Make a list of strings from a tensor
398     def detensorize(self, x):
399         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
400
401     # Trim z to remove as many columns made only of <nul> tokens as
402     # possible from its left and right. If z is a tuple, all its elements
403     # are trimmed according to the trimming computed for its first element.
404     def trim(self, z, token="<nul>"):
405         n = self.token2id[token]
406         if type(z) == tuple:
407             x = z[0]
408             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
409             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
410             return tuple([t[:, a:b] for t in z])
411         else:
412             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
413             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
414             return z[:, a:b]
415
416     ######################
417
418     def __init__(
419         self,
420         nb_train_samples,
421         nb_test_samples,
422         batch_size,
423         height,
424         width,
425         nb_colors=5,
426         logger=None,
427         device=torch.device("cpu"),
428         pruner_train=None,
429         pruner_eval=None,
430     ):
431         super().__init__()
432
433         def generate_descr(nb, cache_suffix, pruner):
434             return picoclvr.generate(
435                 nb,
436                 height=self.height,
437                 width=self.width,
438                 nb_colors=nb_colors,
439                 pruner=pruner,
440             )
441
442         self.height = height
443         self.width = width
444         self.batch_size = batch_size
445         self.device = device
446         self.pruner_train = pruner_train
447         self.pruner_eval = pruner_eval
448
449         if logger is not None:
450             logger(
451                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
452             )
453
454         self.train_descr = generate_descr(
455             nb_train_samples, "train", pruner=self.pruner_train
456         )
457         self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)
458
459         # Build the tokenizer
460         tokens = {"<nul>", "<img>"}
461         for d in [self.train_descr, self.test_descr]:
462             for s in d:
463                 for t in s.strip().split(" "):
464                     tokens.add(t)
465         # make this set a sorted list to get the same tensors given
466         # the same descr
467         tokens = list(tokens)
468         tokens.sort()
469         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
470         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
471         self.t_img, self.t_nul = self.token2id["<img>"], self.token2id["<nul>"]
472
473         # Tokenize the train and test sets
474         self.train_input = self.tensorize(self.train_descr)
475         self.test_input = self.tensorize(self.test_descr)
476
477     def batches(self, split="train"):
478         assert split in {"train", "test"}
479         input = self.train_input if split == "train" else self.test_input
480         for batch in tqdm.tqdm(
481             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
482         ):
483             yield self.trim(batch)
484
485     def vocabulary_size(self):
486         return len(self.token2id)
487
488     def compute_missing_properties(
489         self, n_epoch, model, logger, deterministic_synthesis, pruner=None
490     ):
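        # Erase the image part of each test sequence (from the <img> token
        # onward), let the model regenerate it, and count how many of the
        # requested properties are missing from the generated image.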
491         acc_nb_requested_properties = []
492         acc_nb_missing_properties = []
493         acc_nb_results = 0
494
495         for input in tqdm.tqdm(
496             self.test_input.split(self.batch_size),
497             dynamic_ncols=True,
498             desc=f"test-properties",
499         ):
500             result = input.clone()
501             ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1)
502             result = (1 - ar_mask) * result + ar_mask * self.t_nul
503             masked_inplace_autoregression(
504                 model,
505                 self.batch_size,
506                 result,
507                 ar_mask,
508                 deterministic_synthesis,
509                 progress_bar_desc=None,
510                 device=self.device,
511             )
512
513             result_descr = self.detensorize(result)
514             np = picoclvr.nb_properties(
515                 result_descr,
516                 height=self.height,
517                 width=self.width,
518                 pruner=pruner,
519             )
520             nb_requested_properties, _, nb_missing_properties = zip(*np)
521             acc_nb_requested_properties += nb_requested_properties
522             acc_nb_missing_properties += nb_missing_properties
523             acc_nb_results += len(result_descr)
524
525         nb_requested_properties = sum(acc_nb_requested_properties)
526         nb_missing_properties = sum(acc_nb_missing_properties)
527
528         prefix = "" if pruner is None else "pruned_"
529         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
530         logger(
531             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
532         )
533         logger(
534             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
535         )
536
537         logger(
538             f"main_test_accuracy {n_epoch} {1-nb_missing_properties/nb_requested_properties}"
539         )
540
541     ######################################################################
542
543     def produce_results(
544         self, n_epoch, model, result_dir, logger, deterministic_synthesis
545     ):
546         self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis)
547
548         if self.pruner_eval is not None:
549             self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis, self.pruner_eval)
550
551         nb_tokens_to_generate = self.height * self.width + 3
552         result_descr = []
553         nb_per_primer = 8
554         primer = []
555
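        # Generate nb_per_primer images for each of the hand-written prompts
        # below, and save them all as a single image grid.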
556         for primer_descr in [
557             "red above green <sep> green top <sep> blue right of red",
558             "there is red <sep> there is yellow <sep> there is blue",
559             "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
560             "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
561         ]:
562             primer += [primer_descr + " <img>"] * nb_per_primer
563
564         result = self.tensorize(primer)
565         fill = result.new_full(
566             result.size()[:-1] + (self.height * self.width + 1,), self.t_nul
567         )
568         result = torch.cat((result, fill), 1)
569         ar_mask = (result == self.t_nul).long()
570         masked_inplace_autoregression(
571             model,
572             self.batch_size,
573             result,
574             ar_mask,
575             deterministic_synthesis,
576             device=self.device,
577         )
578         result_descr = self.detensorize(result)
579
580         np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
581
582         acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
583         acc_nb_results = len(result_descr)
584
585         nb_requested_properties = sum(acc_nb_requested_properties)
586         nb_missing_properties = sum(acc_nb_missing_properties)
587
588         prefix = "demo_"
589         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
590         logger(
591             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
592         )
593         logger(
594             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
595         )
596
597         img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)
598
599         if img.dim() == 5:
600             if img.size(1) == 1:
601                 img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
602             else:
603                 img = torch.cat(
604                     [
605                         torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
606                         for x in img
607                     ],
608                     0,
609                 )
610
611         image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png")
612         torchvision.utils.save_image(
613             img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
614         )
615         logger(f"wrote {image_name}")
616
617
618 ######################################################################
619
620
621 class MNIST(Task):
622     def __init__(
623         self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
624     ):
625         super().__init__()
626
627         self.nb_train_samples = nb_train_samples
628         self.nb_test_samples = nb_test_samples
629         self.batch_size = batch_size
630         self.device = device
631         data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True)
632         self.train_input = data_set.data[:nb_train_samples].view(-1, 28 * 28).long()
633         data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
634         self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
635
636     def batches(self, split="train", nb_to_use=-1, desc=None):
637         assert split in {"train", "test"}
638         input = self.train_input if split == "train" else self.test_input
639         if nb_to_use > 0:
640             input = input[:nb_to_use]
641         if desc is None:
642             desc = f"epoch-{split}"
643         for batch in tqdm.tqdm(
644             input.split(self.batch_size), dynamic_ncols=True, desc=desc
645         ):
646             yield batch
647
648     def vocabulary_size(self):
649         return 256
650
651     def produce_results(
652         self, n_epoch, model, result_dir, logger, deterministic_synthesis
653     ):
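        # Generate 64 MNIST images token by token from scratch (every pixel is
        # masked) and save them as a 16-column grid.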
654         results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
655         ar_mask = torch.full_like(results, 1)
656         masked_inplace_autoregression(
657             model,
658             self.batch_size,
659             results,
660             ar_mask,
661             deterministic_synthesis,
662             device=self.device,
663         )
664         image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
665         torchvision.utils.save_image(
666             1 - results.reshape(-1, 1, 28, 28) / 255.0,
667             image_name,
668             nrow=16,
669             pad_value=0.8,
670         )
671         logger(f"wrote {image_name}")
672
673
674 ######################################################################
675
676 import maze
677
678
679 class Maze(Task):
680     def map2seq(self, *m):
681         return torch.cat([x.flatten(1) for x in m], 1)
682
683     def seq2map(self, s):
684         s = s.reshape(s.size(0), -1, self.height, self.width)
685         return (s[:, k] for k in range(s.size(1)))
686
687     def __init__(
688         self,
689         nb_train_samples,
690         nb_test_samples,
691         batch_size,
692         height,
693         width,
694         nb_walls,
695         device=torch.device("cpu"),
696     ):
697         super().__init__()
698
699         self.batch_size = batch_size
700         self.height = height
701         self.width = width
702         self.device = device
703
704         train_mazes, train_paths, _ = maze.create_maze_data(
705             nb_train_samples,
706             height=height,
707             width=width,
708             nb_walls=nb_walls,
709             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
710         )
711         self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
712
713         test_mazes, test_paths, _ = maze.create_maze_data(
714             nb_test_samples,
715             height=height,
716             width=width,
717             nb_walls=nb_walls,
718             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
719         )
720         self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
721
722         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
723
724     def batches(self, split="train", nb_to_use=-1, desc=None):
725         assert split in {"train", "test"}
726         input = self.train_input if split == "train" else self.test_input
727         if nb_to_use > 0:
728             input = input[:nb_to_use]
729         if desc is None:
730             desc = f"epoch-{split}"
731         for batch in tqdm.tqdm(
732             input.split(self.batch_size), dynamic_ncols=True, desc=desc
733         ):
734             yield batch
735
736     def vocabulary_size(self):
737         return self.nb_codes
738
739     def compute_error(
740         self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
741     ):
742         nb_total, nb_correct = 0, 0
743         count = torch.zeros(
744             self.width * self.height,
745             self.width * self.height,
746             device=self.device,
747             dtype=torch.int64,
748         )
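        # count[i, j] is the number of correctly solved mazes whose optimal
        # path occupies i cells while the predicted path occupies j.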
749
750         for input in self.batches(split, nb_to_use):
751             result = input.clone()
752             ar_mask = result.new_zeros(result.size())
753             ar_mask[:, self.height * self.width :] = 1
754             result *= 1 - ar_mask
755             masked_inplace_autoregression(
756                 model,
757                 self.batch_size,
758                 result,
759                 ar_mask,
760                 deterministic_synthesis,
761                 progress_bar_desc=None,
762                 device=self.device,
763             )
764             mazes, paths = self.seq2map(result)
765             path_correctness = maze.path_correctness(mazes, paths)
766             nb_correct += path_correctness.long().sum()
767             nb_total += mazes.size(0)
768
769             optimal_path_lengths = (
770                 (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
771             )
772             predicted_path_lengths = (
773                 (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
774             )
775             optimal_path_lengths = optimal_path_lengths[path_correctness]
776             predicted_path_lengths = predicted_path_lengths[path_correctness]
777             count.index_put_((optimal_path_lengths, predicted_path_lengths), torch.ones_like(optimal_path_lengths), accumulate=True)
778
779         if count.max() == 0:
780             count = None
781         else:
782             count = count[
783                 : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
784             ]
785
786         return nb_total, nb_correct, count
787
788     def produce_results(
789         self, n_epoch, model, result_dir, logger, deterministic_synthesis
790     ):
791         train_nb_total, train_nb_correct, count = self.compute_error(
792             model,
793             "train",
794             nb_to_use=1000,
795             deterministic_synthesis=deterministic_synthesis,
796         )
797         logger(
798             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
799         )
800
801         test_nb_total, test_nb_correct, count = self.compute_error(
802             model,
803             "test",
804             nb_to_use=1000,
805             deterministic_synthesis=deterministic_synthesis,
806         )
807         logger(
808             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
809         )
810
811         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
812
813         if count is not None:
814             proportion_optimal = count.diagonal().sum().float() / count.sum()
815             logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
816             with open(
817                 os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
818             ) as f:
819                 for i in range(count.size(0)):
820                     for j in range(count.size(1)):
821                         eol = " " if j < count.size(1) - 1 else "\n"
822                         f.write(f"{count[i,j]}{eol}")
823
824         input = self.test_input[:48]
825         result = input.clone()
826         ar_mask = result.new_zeros(result.size())
827         ar_mask[:, self.height * self.width :] = 1
828         result *= 1 - ar_mask
829         masked_inplace_autoregression(
830             model,
831             self.batch_size,
832             result,
833             ar_mask,
834             deterministic_synthesis,
835             device=self.device,
836         )
837
838         mazes, paths = self.seq2map(input)
839         _, predicted_paths = self.seq2map(result)
840
841         filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png")
842         maze.save_image(
843             filename,
844             mazes=mazes,
845             target_paths=paths,
846             predicted_paths=predicted_paths,
847             path_correct=maze.path_correctness(mazes, predicted_paths),
848             path_optimal=maze.path_optimality(paths, predicted_paths),
849         )
850         logger(f"wrote {filename}")
851
852
853 ######################################################################
854
855
856 import snake
857
858
859 class Snake(Task):
860     def __init__(
861         self,
862         nb_train_samples,
863         nb_test_samples,
864         batch_size,
865         height,
866         width,
867         nb_colors,
868         length,
869         prompt_length,
870         device=torch.device("cpu"),
871     ):
872         super().__init__()
873
874         self.batch_size = batch_size
875         self.height = height
876         self.width = width
877         self.device = device
878         self.prompt_length = prompt_length
879
880         self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
881             nb_train_samples,
882             height,
883             width,
884             nb_colors,
885             length,
886             prompt_length,
887             self.device,
888         )
889         self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
890             nb_test_samples,
891             height,
892             width,
893             nb_colors,
894             length,
895             prompt_length,
896             self.device,
897         )
898
899         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
900
901     def batches(self, split="train", nb_to_use=-1, desc=None):
902         assert split in {"train", "test"}
903         input = self.train_input if split == "train" else self.test_input
904         if nb_to_use > 0:
905             input = input[:nb_to_use]
906         if desc is None:
907             desc = f"epoch-{split}"
908         for batch in tqdm.tqdm(
909             input.split(self.batch_size), dynamic_ncols=True, desc=desc
910         ):
911             yield batch
912
913     def vocabulary_size(self):
914         return self.nb_codes
915
916     def produce_results(
917         self, n_epoch, model, result_dir, logger, deterministic_synthesis
918     ):
919         def compute_nb_correct(input, prior_visits):
920             result = input.clone()
921             i = torch.arange(result.size(1), device=result.device)[None, :]
922             ar_mask = (
923                 torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
924                 .long()
925                 .expand_as(result)
926             )
927             result *= 1 - ar_mask
928
929             masked_inplace_autoregression(
930                 model,
931                 self.batch_size,
932                 result,
933                 ar_mask,
934                 deterministic_synthesis,
935                 device=self.device,
936             )
937
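            # Only generated positions whose cell had already been visited
            # (prior_visits > 0) are taken into account for the accuracy.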
938             nb_total = ((prior_visits > 0) * ar_mask).sum()
939
940             nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum()
941
942             return nb_total, nb_correct
943
944         test_nb_total, test_nb_correct = compute_nb_correct(
945             self.test_input[:1000], self.test_prior_visits[:1000]
946         )
947
948         logger(
949             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
950         )
951
952         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
953
954
955 ######################################################################
956
957
958 import stack
959
960
961 class Stack(Task):
962     def __init__(
963         self,
964         nb_train_samples,
965         nb_test_samples,
966         batch_size,
967         logger,
968         nb_steps,
969         nb_stacks,
970         nb_digits,
971         fraction_values_for_train=None,
972         device=torch.device("cpu"),
973     ):
974         super().__init__()
975
976         self.batch_size = batch_size
977         self.nb_steps = nb_steps
978         self.nb_stacks = nb_stacks
979         self.nb_digits = nb_digits
980         self.device = device
981
982         if fraction_values_for_train is None:
983             values_for_train = None
984             values_for_test = None
985         else:
986             all_values = torch.randperm(10**nb_digits)
987             nb_for_train = int(all_values.size(0) * fraction_values_for_train)
988             values_for_train = all_values[:nb_for_train]
989             values_for_test = all_values[nb_for_train:]
990
991         self.train_input, self.train_stack_counts = stack.generate_sequences(
992             nb_train_samples,
993             nb_steps,
994             nb_stacks,
995             nb_digits,
996             values_for_train,
997             self.device,
998         )
999
1000         self.test_input, self.test_stack_counts = stack.generate_sequences(
1001             nb_test_samples,
1002             nb_steps,
1003             nb_stacks,
1004             nb_digits,
1005             values_for_test,
1006             self.device,
1007         )
1008
1009         i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
1010         counts = self.test_stack_counts.flatten()[i.flatten()]
1011         counts = F.one_hot(counts).sum(0)
1012         logger(f"test_pop_stack_counts {counts}")
1013
1014         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1015
1016     def batches(self, split="train", nb_to_use=-1, desc=None):
1017         assert split in {"train", "test"}
1018         input = self.train_input if split == "train" else self.test_input
1019         if nb_to_use > 0:
1020             input = input[:nb_to_use]
1021         if desc is None:
1022             desc = f"epoch-{split}"
1023         for batch in tqdm.tqdm(
1024             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1025         ):
1026             yield batch
1027
1028     def vocabulary_size(self):
1029         return self.nb_codes
1030
1031     def produce_results(
1032         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1033     ):
1034         def compute_nb_correct(input):
1035             result = input.clone()
1036             stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
1037             ar_mask = (result != input).long()
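            # The popped values removed above are the positions to regenerate;
            # correctness is then assessed per group of (1 + nb_digits) tokens,
            # a group being counted as correct only if every regenerated token
            # in it is right.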
1038             masked_inplace_autoregression(
1039                 model,
1040                 self.batch_size,
1041                 result,
1042                 ar_mask,
1043                 deterministic_synthesis,
1044                 device=self.device,
1045             )
1046
1047             errors = ((result != input).long() * ar_mask).reshape(
1048                 -1, 1 + self.nb_digits
1049             )
1050             ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)
1051
1052             nb_total = ar_mask.max(1).values.sum()
1053             nb_correct = nb_total - errors.max(1).values.sum()
1054
1055             return nb_total, nb_correct
1056
1057         test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
1058
1059         logger(
1060             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1061         )
1062
1063         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
1064
1065         ##############################################################
1066         # Log a few generated sequences
1067         input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
1068         result = input.clone()
1069         stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
1070         ar_mask = (result != input).long()
1071
1072         # for n in range(result.size(0)):
1073         # logger(
1074         # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
1075         # )
1076
1077         masked_inplace_autoregression(
1078             model,
1079             self.batch_size,
1080             result,
1081             ar_mask,
1082             deterministic_synthesis,
1083             device=self.device,
1084         )
1085
1086         for n in range(result.size(0)):
1087             logger(
1088                 f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
1089             )
1090         ##############################################################
1091
1092
1093 ######################################################################
1094
1095 import rpl
1096
1097
1098 class RPL(Task):
1099     def tensorize(self, sequences):
1100         len_max = max([len(x) for x in sequences])
1101         return torch.cat(
1102             [
1103                 torch.tensor(
1104                     [
1105                         [
1106                             self.token2id[str(c)]
1107                             for c in s + ["<nul>"] * (len_max - len(s))
1108                         ]
1109                         for s in sequences
1110                     ]
1111                 )
1112             ],
1113             0,
1114         )
1115
1116     def seq2str(self, seq):
1117         return " ".join([self.id2token[i] for i in seq])
1118
1119     def __init__(
1120         self,
1121         nb_train_samples,
1122         nb_test_samples,
1123         batch_size,
1124         nb_starting_values=3,
1125         max_input=9,
1126         prog_len=6,
1127         nb_runs=5,
1128         no_prog=False,
1129         logger=None,
1130         device=torch.device("cpu"),
1131     ):
1132         super().__init__()
1133
1134         self.batch_size = batch_size
1135         self.device = device
1136         self.no_prog = no_prog
1137
1138         train_sequences = [
1139             rpl.generate(
1140                 nb_starting_values=nb_starting_values,
1141                 nb_result_values_max=4 * nb_starting_values,
1142                 max_input=max_input,
1143                 prog_len=prog_len,
1144                 nb_runs=nb_runs,
1145             )
1146             for _ in tqdm.tqdm(range(nb_train_samples), desc="train-data")
1147         ]
1148
1149         test_sequences = [
1150             rpl.generate(
1151                 nb_starting_values=nb_starting_values,
1152                 nb_result_values_max=4 * nb_starting_values,
1153                 max_input=max_input,
1154                 prog_len=prog_len,
1155                 nb_runs=nb_runs,
1156             )
1157             for _ in tqdm.tqdm(range(nb_test_samples), desc="test-data")
1158         ]
1159
1160         symbols = list(
1161             set(["<nul>"] + [x for l in train_sequences + test_sequences for x in l])
1162         )
1163         val_max = max([x if type(x) is int else 0 for x in symbols])
1164         symbols = list(filter(lambda x: type(x) is str, symbols))
1165         symbols.sort()
1166         symbols += [str(n) for n in range(val_max + 1)]
1167         self.token2id = dict([(c, n) for n, c in enumerate(symbols)])
1168         self.id2token = dict([(n, c) for c, n in self.token2id.items()])
1169
1170         self.t_nul = self.token2id["<nul>"]
1171         self.t_input = self.token2id["<in>"]
1172         self.t_output = self.token2id["<out>"]
1173         self.t_prog = self.token2id["<prg>"]
1174         self.t_end = self.token2id["<end>"]
1175
1176         self.train_input = self.tensorize(train_sequences)
1177         self.test_input = self.tensorize(test_sequences)
1178
1179         if no_prog:
1180             # Excise the program from every train and test example
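            # k is the position index and p the position of the <prg> marker;
            # every token after p is replaced by a single <end> followed by
            # <nul> padding.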
1181             k = torch.arange(self.train_input.size(1), device=self.train_input.device)[
1182                 None, :
1183             ]
1184             p = (
1185                 ((self.train_input == self.t_prog).long() * k)
1186                 .max(1, keepdim=True)
1187                 .values
1188             )
1189             self.train_input = (
1190                 self.train_input * (k <= p).long()
1191                 + self.t_end * (k == p + 1).long()
1192                 + self.t_nul * (k > p + 1).long()
1193             )
1194             k = torch.arange(self.test_input.size(1), device=self.test_input.device)[
1195                 None, :
1196             ]
1197             p = (
1198                 ((self.test_input == self.t_prog).long() * k)
1199                 .max(1, keepdim=True)
1200                 .values
1201             )
1202             self.test_input = (
1203                 self.test_input * (k <= p).long()
1204                 + self.t_end * (k == p + 1).long()
1205                 + self.t_nul * (k > p + 1).long()
1206             )
1207
1208         if logger is not None:
1209             logger(f"value_max {val_max}")
1210             for x in self.train_input[:25]:
1211                 end = (x != self.t_nul).nonzero().max().item() + 1
1212                 seq = [self.id2token[i.item()] for i in x[:end]]
1213                 s = " ".join(seq)
1214                 logger(f"example_seq {s}")
1215
1216         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1217
1218     def batches(self, split="train", nb_to_use=-1, desc=None):
1219         assert split in {"train", "test"}
1220         input = self.train_input if split == "train" else self.test_input
1221         if nb_to_use > 0:
1222             input = input[:nb_to_use]
1223         if desc is None:
1224             desc = f"epoch-{split}"
1225         for batch in tqdm.tqdm(
1226             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1227         ):
1228             last = (batch != self.t_nul).max(0).values.nonzero().max() + 3
1229             batch = batch[:, :last].to(self.device)
1230             yield batch
1231
1232     def vocabulary_size(self):
1233         return self.nb_codes
1234
1235     def produce_results(
1236         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1237     ):
1238         # --------------------------------------------------------------------
1239         def compute_nb_errors_prog(input, nb_to_log=0):
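            # Erase everything that follows the <prg> marker, let the model
            # regenerate the program, and let rpl.compute_nb_errors score the
            # completed sequences.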
1240             result = input.clone()
1241             s = (result == self.t_prog).long()
1242             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1243             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1244
1245             masked_inplace_autoregression(
1246                 model,
1247                 self.batch_size,
1248                 result,
1249                 ar_mask,
1250                 deterministic_synthesis,
1251                 device=self.device,
1252             )
1253
1254             sum_nb_total, sum_nb_errors = 0, 0
1255             for one_input, one_result in zip(input, result):
1256                 seq = [self.id2token[i.item()] for i in one_result]
1257                 nb_total, nb_errors, prog, stacks = rpl.compute_nb_errors(seq)
1258                 sum_nb_total += 1
1259                 sum_nb_errors += 0 if nb_errors == 0 else 1
1260                 if nb_to_log > 0:
1261                     gt_seq = [self.id2token[i.item()] for i in one_input]
1262                     _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq)
1263                     gt_prog = " ".join([str(x) for x in gt_prog])
1264                     prog = " ".join([str(x) for x in prog])
1265                     comment = "*" if nb_errors == 0 else "-"
1266                     logger(f"{comment} PROG [{gt_prog}] PREDICTED [{prog}]")
1267                     for start_stack, target_stack, result_stack, correct in stacks:
1268                         comment = "*" if correct else "-"
1269                         start_stack = " ".join([str(x) for x in start_stack])
1270                         target_stack = " ".join([str(x) for x in target_stack])
1271                         result_stack = " ".join([str(x) for x in result_stack])
1272                         logger(
1273                             f"  {comment} [{start_stack}] -> [{target_stack}] PREDICTED [{result_stack}]"
1274                         )
1275                     nb_to_log -= 1
1276
1277             return sum_nb_total, sum_nb_errors
1278
1279         # --------------------------------------------------------------------
1280         def compute_nb_errors_output(input, nb_to_log=0):
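            # Erase the values between the last <out> marker and the <prg>
            # marker, let the model regenerate them, and count a sample as
            # correct only if the whole sequence is reproduced exactly.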
1281             result = input.clone()
1282             k = torch.arange(result.size(1), device=result.device)[None, :]
1283             last_output_idx = (
1284                 ((result == self.t_output) * k).max(dim=1, keepdim=True).values
1285             )
1286             first_prog_idx = (
1287                 ((result == self.t_prog) * k).max(dim=1, keepdim=True).values
1288             )
1289             ar_mask = (k > last_output_idx).long() * (k < first_prog_idx).long()
1290             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1291
1292             masked_inplace_autoregression(
1293                 model,
1294                 self.batch_size,
1295                 result,
1296                 ar_mask,
1297                 deterministic_synthesis,
1298                 device=self.device,
1299             )
1300
1301             sum_nb_total, sum_nb_errors = 0, 0
1302             for one_input, one_result, i, j in zip(
1303                 input, result, last_output_idx, first_prog_idx
1304             ):
1305                 seq = [self.id2token[i.item()] for i in one_result]
1306                 sum_nb_total += 1
1307                 correct = (one_input - one_result).abs().max() == 0
1308                 sum_nb_errors += 0 if correct else 1
1309                 if nb_to_log > 0:
1310                     result_stack = [
1311                         self.id2token[i.item()] for i in one_result[i : j + 1]
1312                     ]
1313                     target_stack = [
1314                         self.id2token[i.item()] for i in one_input[i : j + 1]
1315                     ]
1316                     comment = "*" if correct else "-"
1317                     result_stack = " ".join([str(x) for x in result_stack])
1318                     target_stack = " ".join([str(x) for x in target_stack])
1319                     logger(
1320                         f"output_test {comment} [{target_stack}] PREDICTED [{result_stack}]"
1321                     )
1322                     nb_to_log -= 1
1323
1324             return sum_nb_total, sum_nb_errors
1325
1326         # --------------------------------------------------------------------
1327
1328         if not self.no_prog:
1329             test_nb_total, test_nb_errors = compute_nb_errors_prog(
1330                 self.test_input[:1000].to(self.device), nb_to_log=10
1331             )
1332
1333             logger(
1334                 f"accuracy_prog_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1335             )
1336
1337             logger(f"main_test_accuracy {n_epoch} {1-test_nb_errors/test_nb_total}")
1338
1339         test_nb_total, test_nb_errors = compute_nb_errors_output(
1340             self.test_input[:1000].to(self.device), nb_to_log=10
1341         )
1342
1343         logger(
1344             f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1345         )
1346
1347         if save_attention_image is None:
1348             logger("no save_attention_image (is pycairo installed?)")
1349         else:
1350             ns = torch.randint(self.test_input.size(0), (1,)).item()
1351             input = self.test_input[ns : ns + 1].clone()
1352             last = (input != self.t_nul).max(0).values.nonzero().max() + 3
1353             input = input[:, :last].to(self.device)
1354
1355             with torch.autograd.no_grad():
1356                 t = model.training
1357                 model.eval()
1358                 model.record_attention(True)
1359                 model(BracketedSequence(input))
1360                 model.train(t)
1361                 ram = model.retrieve_attention()
1362                 model.record_attention(False)
1363
1364             tokens_output = [self.id2token[i.item()] for i in input[0]]
1365             tokens_input = ["n/a"] + tokens_output[:-1]
1366             for n_head in range(ram[0].size(1)):
1367                 filename = os.path.join(
1368                     result_dir, f"rpl_attention_{n_epoch}_h{n_head}.pdf"
1369                 )
1370                 attention_matrices = [m[0, n_head] for m in ram]
1371                 save_attention_image(
1372                     filename,
1373                     tokens_input,
1374                     tokens_output,
1375                     attention_matrices,
1376                     k_top=10,
1377                     # min_total_attention=0.9,
1378                     token_gap=12,
1379                     layer_gap=50,
1380                 )
1381                 logger(f"wrote {filename}")
1382
1383
1384 ######################################################################
1385
1386
1387 import expr
1388
1389
1390 class Expr(Task):
1391     def tensorize(self, sequences):
1392         len_max = max([len(x) for x in sequences])
1393         return torch.cat(
1394             [
1395                 torch.tensor(
1396                     [
1397                         [self.char2id[c] for c in s + "#" * (len_max - len(s))]
1398                         for s in sequences
1399                     ]
1400                 )
1401             ],
1402             0,
1403         ).to(self.device)
1404
1405     def __init__(
1406         self,
1407         nb_train_samples,
1408         nb_test_samples,
1409         nb_variables,
1410         sequence_length,
1411         operand_max,
1412         result_max,
1413         batch_size,
1414         device=torch.device("cpu"),
1415     ):
1416         super().__init__()
1417
1418         self.batch_size = batch_size
1419         self.device = device
1420
1421         train_sequences = expr.generate_sequences(
1422             nb_train_samples,
1423             nb_variables=nb_variables,
1424             length=sequence_length,
1425             operand_max=operand_max,
1426             result_max=result_max,
1427         )
1428
1429         test_sequences = expr.generate_sequences(
1430             nb_test_samples,
1431             nb_variables=nb_variables,
1432             length=sequence_length,
1433             operand_max=operand_max,
1434             result_max=result_max,
1435         )
1436
1437         symbols = list(set("#" + "".join(train_sequences + test_sequences)))
1438         symbols.sort()
1439
1440         self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
1441         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
1442
1443         self.filler, self.space = self.char2id["#"], self.char2id[" "]
1444
1445         self.train_input = self.tensorize(train_sequences)
1446         self.test_input = self.tensorize(test_sequences)
1447
1448         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1449
1450     def batches(self, split="train", nb_to_use=-1, desc=None):
1451         assert split in {"train", "test"}
1452         input = self.train_input if split == "train" else self.test_input
1453         if nb_to_use > 0:
1454             input = input[:nb_to_use]
1455         if desc is None:
1456             desc = f"epoch-{split}"
1457         for batch in tqdm.tqdm(
1458             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1459         ):
1460             last = (batch != self.filler).max(0).values.nonzero().max() + 3
1461             batch = batch[:, :last]
1462             yield batch
1463
1464     def vocabulary_size(self):
1465         return self.nb_codes
1466
1467     def seq2str(self, s):
1468         return "".join([self.id2char[k.item()] for k in s])
1469
1470     def produce_results(
1471         self,
1472         n_epoch,
1473         model,
1474         result_dir,
1475         logger,
1476         deterministic_synthesis,
1477         input_file=None,
1478     ):
1479         def compute_nb_correct(input):
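            # Everything after the first space character is replaced by the
            # filler token and regenerated; a sequence is counted as correct
            # only if it is reproduced exactly.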
1480             result = input.clone()
1481             s = (result == self.space).long()
1482             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1483             result = (1 - ar_mask) * result + ar_mask * self.filler
1484             masked_inplace_autoregression(
1485                 model,
1486                 self.batch_size,
1487                 result,
1488                 ar_mask,
1489                 deterministic_synthesis,
1490                 device=self.device,
1491             )
1492
1493             nb_total = input.size(0)
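            # a sequence counts as correct only if every one of its tokens
            # matches the ground truth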
1494             nb_correct = (input == result).long().min(1).values.sum()
1495
1496             #######################################################################
1497             # Compare the predicted variable values with the true ones
1498
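            # nb_delta[d] counts predictions at absolute distance d from the
            # true value, for d < 5; missing, negative, or further-off
            # predictions are counted in nb_missed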
1499             nb_delta = torch.zeros(5, dtype=torch.int64)
1500             nb_missed = 0
1501
1502             values_input = expr.extract_results([self.seq2str(s) for s in input])
1503             values_result = expr.extract_results([self.seq2str(s) for s in result])
1504
1505             filename = os.path.join(result_dir, f"expr_result_{n_epoch:04d}.txt")
1506
1507             with open(filename, "w") as f:
1508                 for i, r in zip(values_input, values_result):
1509                     for n, vi in i.items():
1510                         vr = r.get(n)
1511                         f.write(f"{vi} {-1 if vr is None else vr}\n")
1512
1513                         if vr is None or vr < 0:
1514                             nb_missed += 1
1515                         else:
1516                             d = abs(vr - vi)
1517                             if d >= nb_delta.size(0):
1518                                 nb_missed += 1
1519                             else:
1520                                 nb_delta[d] += 1
1521
1522             ######################################################################
1523
1524             return nb_total, nb_correct, nb_delta, nb_missed
1525
1526         (
1527             test_nb_total,
1528             test_nb_correct,
1529             test_nb_delta,
1530             test_nb_missed,
1531         ) = compute_nb_correct(self.test_input[:10000])
1532
1533         logger(
1534             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1535         )
1536
1537         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
1538
1539         nb_total = test_nb_delta.sum() + test_nb_missed
1540         for d in range(test_nb_delta.size(0)):
1541             logger(
1542                 f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%"
1543             )
1544         logger(
1545             f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%"
1546         )
1547
1548         ##############################################################
1549         # Log a few generated sequences
1550         if input_file is None:
1551             input = self.test_input[:10]
1552         else:
1553             with open(input_file, "r") as f:
1554                 sequences = [e.strip() for e in f.readlines()]
1555                 sequences = [s + " " + "#" * 50 for s in sequences]
1556                 input = self.tensorize(sequences)
1557
1558         result = input.clone()
1559         s = (result == self.space).long()
1560         ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1561         result = (1 - ar_mask) * result + ar_mask * self.filler
1562
1563         for n in range(result.size(0)):
1564             logger(f"test_before {self.seq2str(result[n])}")
1565
1566         masked_inplace_autoregression(
1567             model,
1568             self.batch_size,
1569             result,
1570             ar_mask,
1571             deterministic_synthesis,
1572             device=self.device,
1573         )
1574
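        # for logging, replace the prompt part of the ground truth with spaces
        # so that only the values to be predicted remain visible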
1575         correct = (1 - ar_mask) * self.space + ar_mask * input
1576         for n in range(result.size(0)):
1577             comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
1578             logger(f"test_after  {self.seq2str(result[n])} {comment}")
1579             logger(f"truth       {self.seq2str(correct[n])}")
1580         ##############################################################
1581
1582
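# A minimal usage sketch of the Expr task (illustrative only, kept as a
# comment; it assumes a `model` and a `logger` compatible with the calls made
# above, and relies on expr.generate_sequences separating the expressions from
# the values to predict with a space, which is what compute_nb_correct uses):
#
#   task = Expr(
#       nb_train_samples=10000,
#       nb_test_samples=1000,
#       nb_variables=3,
#       sequence_length=40,
#       operand_max=9,
#       result_max=99,
#       batch_size=25,
#   )
#   for input in task.batches(split="train"):
#       ...  # one optimization step of `model` on `input`
#   task.produce_results(n_epoch, model, result_dir, logger, True)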
1583 ######################################################################
1584
1585 import grid
1586
1587
1588 class Grid(Task):
1589     # Make a tensor from a list of strings
1590     def str2tensor(self, descr):
1591         token_descr = [s.strip().split(" ") for s in descr]
1592         l = max([len(s) for s in token_descr])
1593         token_descr = [s + ["#"] * (l - len(s)) for s in token_descr]
1594         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
1595         return torch.tensor(id_descr, device=self.device)
1596
1597     # Make a list of strings from a tensor
1598     def tensor2str(self, x):
1599         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
1600
1601     # Trim z to remove as many columns made only of the given token as
1602     # possible, on the left and on the right. If z is a tuple, all its
1603     # elements are trimmed according to the trimming of the first one.
1604     def trim(self, z, token="#"):
1605         n = self.token2id[token]
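        # pad one all-"token" column on each side, flag the columns holding at
        # least one other symbol, and keep only the span between the first and
        # last of those columns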
1606         if type(z) == tuple:
1607             x = z[0]
1608             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
1609             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
1610             return tuple([t[:, a:b] for t in z])
1611         else:
1612             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
1613             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
1614             return z[:, a:b]
1615
1616     ######################
1617
1618     def __init__(
1619         self,
1620         nb_train_samples,
1621         nb_test_samples,
1622         batch_size,
1623         size,
1624         fraction_play=0.0,
1625         logger=None,
1626         device=torch.device("cpu"),
1627     ):
1628         super().__init__()
1629
1630         self.device = device
1631         self.batch_size = batch_size
1632         self.grid_factory = grid.GridFactory(size=size)
1633         self.fraction_play = fraction_play
1634
1635         if logger is not None:
1636             logger(
1637                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
1638             )
1639
1640         self.train_descr = self.grid_factory.generate_samples(
1641             nb=nb_train_samples,
1642             fraction_play=fraction_play,
1643             progress_bar=lambda r: tqdm.tqdm(r),
1644         )
1645
1646         self.test_descr = self.grid_factory.generate_samples(
1647             nb=nb_test_samples, fraction_play=0.0, progress_bar=lambda r: tqdm.tqdm(r)
1648         )
1649
1650         if fraction_play > 0:
1651             self.play_descr = self.grid_factory.generate_samples(
1652                 nb=25, fraction_play=1.0, progress_bar=lambda r: tqdm.tqdm(r)
1653             )
1654         else:
1655             self.play_descr = []
1656
1657         # Build the tokenizer
1658         tokens = set()
1659         for d in [self.train_descr, self.test_descr, self.play_descr]:
1660             for s in d:
1661                 for t in s.strip().split(" "):
1662                     tokens.add(t)
1663         # turn the set into a sorted list so that identical descriptions
1664         # always produce identical tensors
1665         tokens = list(tokens)
1666         tokens.sort()
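        # "#" is put first so that the null / padding token gets id 0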
1667         tokens = ["#"] + tokens
1668         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
1669         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
1670         self.t_nul = self.token2id["#"]
1671         self.t_true = self.token2id["true"]
1672         self.t_false = self.token2id["false"]
1673         self.t_pipe = self.token2id["|"]
1674
1675         # Tokenize the train and test sets
1676         self.train_input = self.str2tensor(self.train_descr)
1677         self.test_input = self.str2tensor(self.test_descr)
1678         self.play_input = (
1679             None if len(self.play_descr) == 0 else self.str2tensor(self.play_descr)
1680         )
1681
1682     def batches(self, split="train"):
1683         assert split in {"train", "test"}
1684         input = self.train_input if split == "train" else self.test_input
1685         for batch in tqdm.tqdm(
1686             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
1687         ):
1688             yield self.trim(batch)
1689
1690     def vocabulary_size(self):
1691         return len(self.token2id)
1692
1693     def produce_results(
1694         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1695     ):
1696         correct = self.test_input[:1000]
1697         result = correct.clone()
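        # the tokens to predict are the "true" / "false" answers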
1698         ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long()
1699         result *= 1 - ar_mask  # zero out the tokens to be generated (as a precaution)
1700
1701         logger(f"----------------------------------------------------------")
1702
1703         for e in self.tensor2str(result[:10]):
1704             logger(f"test_before {e}")
1705
1706         masked_inplace_autoregression(
1707             model,
1708             self.batch_size,
1709             result,
1710             ar_mask,
1711             deterministic_synthesis,
1712             device=self.device,
1713         )
1714
1715         logger(f"----------------------------------------------------------")
1716
1717         for e in self.tensor2str(result[:10]):
1718             logger(f"test_after  {e}")
1719
1720         logger(f"----------------------------------------------------------")
1721
1722         nb_total = ar_mask.sum().item()
1723         nb_correct = ((correct == result).long() * ar_mask).sum().item()
1724
1725         logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
1726         logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
1727
1728         if self.play_input is not None:
1729             result = self.play_input.clone()
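            # in play mode, everything from the first "|" onward has to be
            # generated by the model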
1730             ar_mask = (result == self.t_pipe).long().cumsum(dim=1).clamp(max=1)
1731             result *= 1 - ar_mask  # zero out the tokens to be generated (as a precaution)
1732
1733             logger(f"----------------------------------------------------------")
1734
1735             for e in self.tensor2str(result[:10]):
1736                 logger(f"play_before {e}")
1737
1738             masked_inplace_autoregression(
1739                 model,
1740                 self.batch_size,
1741                 result,
1742                 ar_mask,
1743                 deterministic_synthesis,
1744                 device=self.device,
1745             )
1746
1747             logger(f"----------------------------------------------------------")
1748
1749             for e in self.tensor2str(result[:10]):
1750                 logger(f"play_after  {e}")
1751
1752             logger(f"----------------------------------------------------------")
1753
1754
1755 ######################################################################
1756
1757 import qmlp
1758
1759
1760 class QMLP(Task):
1761     ######################
1762
1763     def __init__(
1764         self,
1765         nb_train_samples,
1766         nb_test_samples,
1767         batch_size,
1768         result_dir,
1769         logger=None,
1770         device=torch.device("cpu"),
1771     ):
1772         super().__init__()
1773
1774         self.device = device
1775         self.batch_size = batch_size
1776         self.nb_samples_per_mlp = 256
1777
1778         if logger is not None:
1779             logger(
1780                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
1781             )
1782
1783         seq, q_test_set, test_error = qmlp.generate_sequence_and_test_set(
1784             nb_mlps=nb_train_samples + nb_test_samples,
1785             nb_samples=self.nb_samples_per_mlp,
1786             device=self.device,
1787             batch_size=64,
1788             nb_epochs=250,
1789             nb_mlps_per_batch=1024,
1790         )
1791
1792         self.train_input = seq[:nb_train_samples]
1793         self.train_q_test_set = q_test_set[:nb_train_samples]
1794         self.train_ref_test_errors = test_error[:nb_train_samples]
1795         self.test_input = seq[nb_train_samples:]
1796         self.test_q_test_set = q_test_set[nb_train_samples:]
1797         self.test_ref_test_errors = test_error[nb_train_samples:]
1798
1799         filename = os.path.join(result_dir, f"train_errors_ref.dat")
1800         with open(filename, "w") as f:
1801             for e in self.train_ref_test_errors:
1802                 f.write(f"{e}\n")
1803
1804         filename = os.path.join(result_dir, f"test_errors_ref.dat")
1805         with open(filename, "w") as f:
1806             for e in self.test_ref_test_errors:
1807                 f.write(f"{e}\n")
1808
1809         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1810
1811     def batches(self, split="train"):
1812         assert split in {"train", "test"}
1813         input = self.train_input if split == "train" else self.test_input
1814         for batch in tqdm.tqdm(
1815             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
1816         ):
1817             yield batch
1818
1819     def vocabulary_size(self):
1820         return self.nb_codes
1821
1822     def produce_results(
1823         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1824     ):
1825         correct = self.test_input[:1000]
1826         result = correct.clone()
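        # keep everything up to position 3 * nb_samples_per_mlp + 1 (the
        # quantized train set part of the sequence) as the prompt, and let the
        # model generate the remaining positions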
1827         ar_mask = (
1828             torch.arange(result.size(1), device=result.device)
1829             > self.nb_samples_per_mlp * 3 + 1
1830         ).long()[None, :]
1831         ar_mask = ar_mask.expand_as(result)
1832         result *= 1 - ar_mask  # zero out the tokens to be generated (as a precaution)
1833
1834         masked_inplace_autoregression(
1835             model,
1836             self.batch_size,
1837             result,
1838             ar_mask,
1839             deterministic_synthesis,
1840             device=self.device,
1841         )
1842
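        # split the completed sequences back into their quantized train-set
        # part and the generated MLP parameters, then evaluate those parameters
        # on the held-out test set of each MLP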
1843         q_train_set = result[:, : self.nb_samples_per_mlp * 3]
1844         q_params = result[:, self.nb_samples_per_mlp * 3 + 1 :]
1845         error_test = qmlp.evaluate_q_params(q_params, self.test_q_test_set)
1846
1847         filename = os.path.join(result_dir, f"test_errors_{n_epoch:04d}.dat")
1848         with open(filename, "w") as f:
1849             for e in error_test:
1850                 f.write(f"{e}\n")
1851
1852
1853 ######################################################################