[picoclvr.git] / tasks.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, os, tqdm, warnings
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 from mygpt import BracketedSequence
16
17 # from graph import save_attention_image
18 save_attention_image = None
19
20 ######################################################################
21
22
23 def masked_inplace_autoregression(
24     model,
25     batch_size,
26     input,
27     ar_mask,
28     deterministic_synthesis,
29     forbidden_tokens=None,
30     logit_biases=None,
31     progress_bar_desc="autoregression",
32     device=torch.device("cpu"),
33 ):
34     assert input.size() == ar_mask.size()
35
36     batches = zip(input.split(batch_size), ar_mask.split(batch_size))
37
38     if progress_bar_desc is not None:
39         batches = tqdm.tqdm(
40             batches,
41             dynamic_ncols=True,
42             desc=progress_bar_desc,
43             total=(input.size(0) + batch_size - 1) // batch_size,
44         )
45
46     with torch.autograd.no_grad():
47         t = model.training
48         model.eval()
49
50         for input, ar_mask in batches:
51             model.masked_inplace_autoregression(
52                 input,
53                 ar_mask,
54                 deterministic_synthesis,
55                 forbidden_tokens,
56                 logit_biases,
57             )
58
59         model.train(t)
60
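# A minimal usage sketch (illustrative values, assuming `model` is a trained
# mygpt-style model exposing `masked_inplace_autoregression`): `ar_mask`
# marks with 1 the positions to be (re)generated in place, everything
# marked 0 is kept as given.
#
#   input = torch.tensor([[3, 7, 0, 0], [5, 2, 0, 0]])
#   ar_mask = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
#   masked_inplace_autoregression(
#       model, batch_size=2, input=input, ar_mask=ar_mask,
#       deterministic_synthesis=True,
#   )
#   # the last two tokens of each row of `input` now hold generated values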
61
62 ######################################################################
63
64
65 class Task:
66     def batches(self, split="train", nb_to_use=-1, desc=None):
67         pass
68
69     def vocabulary_size(self):
70         pass
71
72     def produce_results(
73         self, n_epoch, model, result_dir, logger, deterministic_synthesis
74     ):
75         pass
76
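# A minimal sketch of what a concrete Task provides (class name and values
# below are made up for illustration): batches() yields int64 tensors of
# shape (N, T), vocabulary_size() returns the number of token ids, and
# produce_results() evaluates the model and logs its metrics.
#
#   class ConstantTask(Task):
#       def batches(self, split="train", nb_to_use=-1, desc=None):
#           yield torch.full((4, 8), 1, dtype=torch.int64)
#
#       def vocabulary_size(self):
#           return 2
#
#       def produce_results(
#           self, n_epoch, model, result_dir, logger, deterministic_synthesis
#       ):
#           logger(f"main_test_accuracy {n_epoch} 1.0")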
77
78 class TaskFromFile(Task):
79     def tensorize(self, pairs, shuffle):
80         len_max = max([len(x[0]) for x in pairs])
81
82         input = torch.cat(
83             [
84                 torch.tensor(
85                     [
86                         [self.char2id[c] for c in s[0] + "#" * (len_max - len(s[0]))]
87                         for s in pairs
88                     ]
89                 )
90             ],
91             0,
92         ).to("cpu")
93
94         pred_mask = torch.cat(
95             [
96                 torch.tensor(
97                     [
98                         [int(c) for c in s[1] + "0" * (len_max - len(s[1]))]
99                         for s in pairs
100                     ]
101                 )
102             ],
103             0,
104         ).to("cpu")
105
106         if shuffle:
107             i = torch.randperm(input.size(0))
108             input = input[i].contiguous()
109             pred_mask = pred_mask[i].contiguous()
110
111         return input, pred_mask
112
113     # trim the tensors to remove as many padding tokens as possible from
114     # the left and right of the first tensor. If z is a tuple, all its
115     # elements are trimmed according to the trimming computed for the first one
116     def trim(self, z, token="#"):
117         n = self.char2id[token]
118         if type(z) == tuple:
119             x = z[0]
120             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
121             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
122             return tuple([t[:, a:b] for t in z])
123         else:
124             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
125             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
126             return z[:, a:b]
127
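    # Illustration of trim() above (made-up strings, "#" being the padding
    # token): in the batch ["##ab#", "#cd##"] only the first and last
    # columns consist entirely of "#", so the trimmed batch keeps columns
    # 1..3, i.e. ["#ab", "cd#"].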
128     def __init__(
129         self,
130         train_filename,
131         test_filename,
132         nb_train_samples,
133         nb_test_samples,
134         batch_size,
135         shuffle=False,
136         device=torch.device("cpu"),
137     ):
138         self.batch_size = batch_size
139         self.device = device
140
141         def read_file(filename, nb=-1):
142             pairs = []
143             with open(filename, "r") as f:
144                 while True:
145                     sequence = f.readline().strip()
146                     if not sequence:
147                         break
148                     pred_mask = f.readline().strip()
149                     assert len(sequence) == len(pred_mask)
150                     assert set(pred_mask).issubset({"0", "1", "2"}), f"{set(pred_mask)}"
151                     pairs.append((sequence, pred_mask))
152                     if len(pairs) == nb:
153                         break
154
155             if nb > 0:
156                 pairs = pairs[:nb]
157                 assert len(pairs) == nb
158
159             return pairs
160
161         train_pairs = read_file(train_filename, nb_train_samples)
162         test_pairs = read_file(test_filename, nb_test_samples)
163
164         symbols = ["#"] + list(
165             set("".join([x[0] for x in train_pairs + test_pairs])) - set(["#"])
166         )
167         self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
168         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
169
170         self.train_input, self.train_pred_masks = self.tensorize(
171             train_pairs, shuffle=shuffle
172         )
173         self.test_input, self.test_pred_masks = self.tensorize(
174             test_pairs, shuffle=shuffle
175         )
176
177     def batches(self, split="train", nb_to_use=-1, desc=None):
178         assert split in {"train", "test"}
179         input = self.train_input if split == "train" else self.test_input
180         if nb_to_use > 0:
181             input = input[:nb_to_use]
182         if desc is None:
183             desc = f"epoch-{split}"
184         for batch in tqdm.tqdm(
185             input.split(self.batch_size), dynamic_ncols=True, desc=desc
186         ):
187             yield self.trim(batch).to(self.device)
188
189     def vocabulary_size(self):
190         return len(self.char2id)
191
192     def tensor2str(self, t):
193         return ["".join([self.id2char[x.item()] for x in s]) for s in t]
194
195     def produce_results(
196         self, n_epoch, model, result_dir, logger, deterministic_synthesis
197     ):
198         correct = self.trim(self.test_input[:1000]).to(self.device)
199         result = correct.clone()
200         pred_mask = self.test_pred_masks[:1000, : result.size(1)].to(self.device)
201         ar_mask = (pred_mask > 0).long()
202         result *= 1 - ar_mask  # paranoia: clear the tokens that will be generated
203
204         logger(f"----------------------------------------------------------")
205
206         for e in self.tensor2str(result[:50]):
207             logger(f"test_before {e}")
208
209         masked_inplace_autoregression(
210             model,
211             self.batch_size,
212             result,
213             ar_mask,
214             deterministic_synthesis,
215             device=self.device,
216         )
217
218         logger(f"----------------------------------------------------------")
219
220         for e, c in zip(self.tensor2str(result[:50]), self.tensor2str(correct[:50])):
221             logger(f"test_after  {e}")
222             logger(f"correct     {c}")
223
224         logger(f"----------------------------------------------------------")
225
226         err_mask = (pred_mask == 2).long()
227         nb_total = err_mask.sum().item()
228         nb_correct = ((correct == result).long() * err_mask).sum().item()
229
230         logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
231         logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
232
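# The files read by TaskFromFile alternate a sequence line and a mask line of
# the same length made of the digits 0/1/2: 0 marks tokens given as context,
# 1 and 2 mark tokens to be generated, and only the positions marked 2 are
# counted in the test accuracy. A made-up pair of lines for illustration:
#
#   3+4=7
#   00002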
233
234 ####################
235
236 import problems
237
238
239 class SandBox(Task):
240     def __init__(
241         self,
242         problem,
243         nb_train_samples,
244         nb_test_samples,
245         batch_size,
246         logger=None,
247         device=torch.device("cpu"),
248         max_nb_codes=1024,
249     ):
250         super().__init__()
251
252         self.batch_size = batch_size
253         self.device = device
254         self.problem = problem
255
256         self.train_input, self.train_ar_mask = self.problem.generate_sequences(
257             nb_train_samples
258         )
259         self.test_input, self.test_ar_mask = self.problem.generate_sequences(
260             nb_test_samples
261         )
262
263         self.train_input, self.train_ar_mask = self.train_input.to(
264             device
265         ), self.train_ar_mask.to(device)
266         self.test_input, self.test_ar_mask = self.test_input.to(
267             device
268         ), self.test_ar_mask.to(device)
269
270         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
271
272         # A bit of paranoia never hurts
273         assert self.nb_codes <= max_nb_codes
274         assert self.train_input.min() >= 0
275         assert self.test_input.min() >= 0
276         assert tuple(x.item() for x in self.train_ar_mask.unique()) in {
277             (0,),
278             (1,),
279             (0, 1),
280         }
281         assert tuple(x.item() for x in self.test_ar_mask.unique()) in {
282             (0,),
283             (1,),
284             (0, 1),
285         }
286
287         if logger is not None:
288             for s, a in zip(self.train_input[:100], self.train_ar_mask[:100]):
289                 logger(f"train_sequences {self.problem.seq2str(s)}")
290                 a = "".join(["01"[x.item()] for x in a])
291                 logger(f"                {a}")
292
293     def batches(self, split="train", nb_to_use=-1, desc=None):
294         assert split in {"train", "test"}
295         input = self.train_input if split == "train" else self.test_input
296         if nb_to_use > 0:
297             input = input[:nb_to_use]
298         if desc is None:
299             desc = f"epoch-{split}"
300         for batch in tqdm.tqdm(
301             input.split(self.batch_size), dynamic_ncols=True, desc=desc
302         ):
303             yield batch
304
305     def vocabulary_size(self):
306         return self.nb_codes
307
308     def produce_results(
309         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
310     ):
311         def compute_accuracy(input, ar_mask, logger=None):
312             input, ar_mask = input[:nmax], ar_mask[:nmax]
313             result = input.clone() * (1 - ar_mask)
314
315             masked_inplace_autoregression(
316                 model,
317                 self.batch_size,
318                 result,
319                 ar_mask,
320                 deterministic_synthesis,
321                 progress_bar_desc=None,
322                 device=self.device,
323             )
324
325             log_ground_truth = ar_mask.min() == 0
326
327             if logger is not None:
328                 for sp, st in zip(result[:10], input[:10]):
329                     logger(
330                         f"test_sequences {n_epoch} prediction   {self.problem.seq2str(sp)}"
331                     )
332                     if log_ground_truth:
333                         logger(
334                             f"               {n_epoch} ground truth {self.problem.seq2str(st)}"
335                         )
336
337             nb_total, nb_correct = self.problem.compute_nb_correct(
338                 input, ar_mask, result
339             )
340
341             # nb_total = ar_mask.sum().item()
342             # nb_correct = ((result == input).long() * ar_mask).sum().item()
343
344             return nb_total, nb_correct
345
346         train_nb_total, train_nb_correct = compute_accuracy(
347             self.train_input, self.train_ar_mask
348         )
349
350         logger(
351             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
352         )
353
354         test_nb_total, test_nb_correct = compute_accuracy(
355             self.test_input, self.test_ar_mask, logger
356         )
357
358         logger(
359             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
360         )
361
362         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
363
364         if save_attention_image is not None:
365             for k in range(10):
366                 ns = torch.randint(self.test_input.size(0), (1,)).item()
367                 input = self.test_input[ns : ns + 1].clone()
368
369                 with torch.autograd.no_grad():
370                     t = model.training
371                     model.eval()
372                     # model.record_attention(True)
373                     model(BracketedSequence(input))
374                     model.train(t)
375                     # ram = model.retrieve_attention()
376                     # model.record_attention(False)
377
378                 # tokens_output = [c for c in self.problem.seq2str(input[0])]
379                 # tokens_input = ["n/a"] + tokens_output[:-1]
380                 # for n_head in range(ram[0].size(1)):
381                 # filename = os.path.join(
382                 # result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
383                 # )
384                 # attention_matrices = [m[0, n_head] for m in ram]
385                 # save_attention_image(
386                 # filename,
387                 # tokens_input,
388                 # tokens_output,
389                 # attention_matrices,
390                 # k_top=10,
391                 ##min_total_attention=0.9,
392                 # token_gap=12,
393                 # layer_gap=50,
394                 # )
395                 # logger(f"wrote {filename}")
396
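# From its use in SandBox above, the `problem` object is expected to provide
#
#   generate_sequences(nb)  -> (input, ar_mask) int64 tensors of shape (N, T)
#   seq2str(seq)            -> a printable string for one sequence
#   compute_nb_correct(input, ar_mask, result) -> (nb_total, nb_correct)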
397
398 ######################################################################
399
400 import picoclvr
401
402
403 class PicoCLVR(Task):
404     # Make a tensor from a list of strings
405     def tensorize(self, descr):
406         token_descr = [s.strip().split(" ") for s in descr]
407         l = max([len(s) for s in token_descr])
408         token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
409         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
410         return torch.tensor(id_descr, device=self.device)
411
412     # Make a list of strings from a tensor
413     def detensorize(self, x):
414         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
415
416     # trim the tensors to remove as many padding tokens as possible from
417     # the left and right of the first tensor. If z is a tuple, all its
418     # elements are trimmed according to the trimming computed for the first one
419     def trim(self, z, token="<nul>"):
420         n = self.token2id[token]
421         if type(z) == tuple:
422             x = z[0]
423             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
424             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
425             return tuple([t[:, a:b] for t in z])
426         else:
427             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
428             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
429             return z[:, a:b]
430
431     ######################
432
433     def __init__(
434         self,
435         nb_train_samples,
436         nb_test_samples,
437         batch_size,
438         height,
439         width,
440         nb_colors=5,
441         logger=None,
442         device=torch.device("cpu"),
443         pruner_train=None,
444         pruner_eval=None,
445     ):
446         super().__init__()
447
448         def generate_descr(nb, cache_suffix, pruner):
449             return picoclvr.generate(
450                 nb,
451                 height=self.height,
452                 width=self.width,
453                 nb_colors=nb_colors,
454                 pruner=pruner,
455             )
456
457         self.height = height
458         self.width = width
459         self.batch_size = batch_size
460         self.device = device
461         self.pruner_train = pruner_train
462         self.pruner_eval = pruner_eval
463
464         if logger is not None:
465             logger(
466                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
467             )
468
469         self.train_descr = generate_descr(
470             nb_train_samples, "train", pruner=self.pruner_train
471         )
472         self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)
473
474         # Build the tokenizer
475         tokens = {"<nul>", "<img>"}
476         for d in [self.train_descr, self.test_descr]:
477             for s in d:
478                 for t in s.strip().split(" "):
479                     tokens.add(t)
480         # make this set into a sorted list so that the same descriptions
481         # always yield the same tensors
482         tokens = list(tokens)
483         tokens.sort()
484         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
485         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
486         self.t_img, self.t_nul = self.token2id["<img>"], self.token2id["<nul>"]
487
488         # Tokenize the train and test sets
489         self.train_input = self.tensorize(self.train_descr)
490         self.test_input = self.tensorize(self.test_descr)
491
492     def batches(self, split="train", nb_to_use=-1, desc=None):
493         assert split in {"train", "test"}
494         input = self.train_input if split == "train" else self.test_input
495         for batch in tqdm.tqdm(
496             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
497         ):
498             yield self.trim(batch)
499
500     def vocabulary_size(self):
501         return len(self.token2id)
502
503     def compute_missing_properties(
504         self, n_epoch, model, logger, deterministic_synthesis, pruner=None
505     ):
506         acc_nb_requested_properties = []
507         acc_nb_missing_properties = []
508         acc_nb_results = 0
509
510         for input in tqdm.tqdm(
511             self.test_input.split(self.batch_size),
512             dynamic_ncols=True,
513             desc=f"test-properties",
514         ):
515             result = input.clone()
516             ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1)
517             result = (1 - ar_mask) * result + ar_mask * self.t_nul
518             masked_inplace_autoregression(
519                 model,
520                 self.batch_size,
521                 result,
522                 ar_mask,
523                 deterministic_synthesis,
524                 progress_bar_desc=None,
525                 device=self.device,
526             )
527
528             result_descr = self.detensorize(result)
529             np = picoclvr.nb_properties(
530                 result_descr,
531                 height=self.height,
532                 width=self.width,
533                 pruner=pruner,
534             )
535             nb_requested_properties, _, nb_missing_properties = zip(*np)
536             acc_nb_requested_properties += nb_requested_properties
537             acc_nb_missing_properties += nb_missing_properties
538             acc_nb_results += len(result_descr)
539
540         nb_requested_properties = sum(acc_nb_requested_properties)
541         nb_missing_properties = sum(acc_nb_missing_properties)
542
543         prefix = "" if pruner is None else "pruned_"
544         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
545         logger(
546             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
547         )
548         logger(
549             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
550         )
551
552         logger(
553             f"main_test_accuracy {n_epoch} {1-nb_missing_properties/nb_requested_properties}"
554         )
555
556     ######################################################################
557
558     def produce_results(
559         self, n_epoch, model, result_dir, logger, deterministic_synthesis
560     ):
561         self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis)
562
563         if self.pruner_eval is not None:
564             self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis, pruner=self.pruner_eval)
565
566         nb_tokens_to_generate = self.height * self.width + 3
567         result_descr = []
568         nb_per_primer = 8
569         primer = []
570
571         for primer_descr in [
572             "red above green <sep> green top <sep> blue right of red",
573             "there is red <sep> there is yellow <sep> there is blue",
574             "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
575             "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
576         ]:
577             primer += [primer_descr + " <img>"] * nb_per_primer
578
579         result = self.tensorize(primer)
580         fill = result.new_full(
581             result.size()[:-1] + (self.height * self.width + 1,), self.t_nul
582         )
583         result = torch.cat((result, fill), 1)
584         ar_mask = (result == self.t_nul).long()
585         masked_inplace_autoregression(
586             model,
587             self.batch_size,
588             result,
589             ar_mask,
590             deterministic_synthesis,
591             device=self.device,
592         )
593         result_descr = self.detensorize(result)
594
595         np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
596
597         acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
598         acc_nb_results = len(result_descr)
599
600         nb_requested_properties = sum(acc_nb_requested_properties)
601         nb_missing_properties = sum(acc_nb_missing_properties)
602
603         prefix = "demo_"
604         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
605         logger(
606             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
607         )
608         logger(
609             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
610         )
611
612         img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)
613
614         if img.dim() == 5:
615             if img.size(1) == 1:
616                 img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
617             else:
618                 img = torch.cat(
619                     [
620                         torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
621                         for x in img
622                     ],
623                     0,
624                 )
625
626         image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png")
627         torchvision.utils.save_image(
628             img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
629         )
630         logger(f"wrote {image_name}")
631
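# In a PicoCLVR sequence the textual description of the properties comes
# first, followed by the "<img>" separator and the grid tokens.
# compute_missing_properties() masks everything from "<img>" onward, lets the
# model regenerate the image, and counts with picoclvr.nb_properties() how
# many of the requested properties are missing from the result.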
632
633 ######################################################################
634
635
636 class MNIST(Task):
637     def __init__(
638         self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
639     ):
640         super().__init__()
641
642         self.nb_train_samples = nb_train_samples
643         self.nb_test_samples = nb_test_samples
644         self.batch_size = batch_size
645         self.device = device
646         data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True)
647         self.train_input = data_set.data[:nb_train_samples].view(-1, 28 * 28).long()
648         data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
649         self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
650
651     def batches(self, split="train", nb_to_use=-1, desc=None):
652         assert split in {"train", "test"}
653         input = self.train_input if split == "train" else self.test_input
654         if nb_to_use > 0:
655             input = input[:nb_to_use]
656         if desc is None:
657             desc = f"epoch-{split}"
658         for batch in tqdm.tqdm(
659             input.split(self.batch_size), dynamic_ncols=True, desc=desc
660         ):
661             yield batch
662
663     def vocabulary_size(self):
664         return 256
665
666     def produce_results(
667         self, n_epoch, model, result_dir, logger, deterministic_synthesis
668     ):
669         results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
670         ar_mask = torch.full_like(results, 1)
671         masked_inplace_autoregression(
672             model,
673             self.batch_size,
674             results,
675             ar_mask,
676             deterministic_synthesis,
677             device=self.device,
678         )
679         image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
680         torchvision.utils.save_image(
681             1 - results.reshape(-1, 1, 28, 28) / 255.0,
682             image_name,
683             nrow=16,
684             pad_value=0.8,
685         )
686         logger(f"wrote {image_name}")
687
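# Each MNIST digit is flattened into a sequence of 28 * 28 = 784 tokens whose
# values are the raw pixel intensities 0..255, hence vocabulary_size() is 256.
# produce_results() samples 64 complete images from scratch (ar_mask is all
# ones) and writes them to disk as a 16-column grid.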
688
689 ######################################################################
690
691 import maze
692
693
694 class Maze(Task):
695     def map2seq(self, *m):
696         return torch.cat([x.flatten(1) for x in m], 1)
697
698     def seq2map(self, s):
699         s = s.reshape(s.size(0), -1, self.height, self.width)
700         return (s[:, k] for k in range(s.size(1)))
701
702     def __init__(
703         self,
704         nb_train_samples,
705         nb_test_samples,
706         batch_size,
707         height,
708         width,
709         nb_walls,
710         device=torch.device("cpu"),
711     ):
712         super().__init__()
713
714         self.batch_size = batch_size
715         self.height = height
716         self.width = width
717         self.device = device
718
719         train_mazes, train_paths, _ = maze.create_maze_data(
720             nb_train_samples,
721             height=height,
722             width=width,
723             nb_walls=nb_walls,
724             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
725         )
726         self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
727
728         test_mazes, test_paths, _ = maze.create_maze_data(
729             nb_test_samples,
730             height=height,
731             width=width,
732             nb_walls=nb_walls,
733             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
734         )
735         self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
736
737         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
738
739     def batches(self, split="train", nb_to_use=-1, desc=None):
740         assert split in {"train", "test"}
741         input = self.train_input if split == "train" else self.test_input
742         if nb_to_use > 0:
743             input = input[:nb_to_use]
744         if desc is None:
745             desc = f"epoch-{split}"
746         for batch in tqdm.tqdm(
747             input.split(self.batch_size), dynamic_ncols=True, desc=desc
748         ):
749             yield batch
750
751     def vocabulary_size(self):
752         return self.nb_codes
753
754     def compute_error(
755         self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
756     ):
757         nb_total, nb_correct = 0, 0
758         count = torch.zeros(
759             self.width * self.height,
760             self.width * self.height,
761             device=self.device,
762             dtype=torch.int64,
763         )
764
765         for input in self.batches(split, nb_to_use):
766             result = input.clone()
767             ar_mask = result.new_zeros(result.size())
768             ar_mask[:, self.height * self.width :] = 1
769             result *= 1 - ar_mask
770             masked_inplace_autoregression(
771                 model,
772                 self.batch_size,
773                 result,
774                 ar_mask,
775                 deterministic_synthesis,
776                 progress_bar_desc=None,
777                 device=self.device,
778             )
779             mazes, paths = self.seq2map(result)
780             path_correctness = maze.path_correctness(mazes, paths)
781             nb_correct += path_correctness.long().sum()
782             nb_total += mazes.size(0)
783
784             optimal_path_lengths = (
785                 (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
786             )
787             predicted_path_lengths = (
788                 (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
789             )
790             optimal_path_lengths = optimal_path_lengths[path_correctness]
791             predicted_path_lengths = predicted_path_lengths[path_correctness]
792             count[optimal_path_lengths, predicted_path_lengths] += 1
793
794         if count.max() == 0:
795             count = None
796         else:
797             count = count[
798                 : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
799             ]
800
801         return nb_total, nb_correct, count
802
803     def produce_results(
804         self, n_epoch, model, result_dir, logger, deterministic_synthesis
805     ):
806         train_nb_total, train_nb_correct, count = self.compute_error(
807             model,
808             "train",
809             nb_to_use=1000,
810             deterministic_synthesis=deterministic_synthesis,
811         )
812         logger(
813             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
814         )
815
816         test_nb_total, test_nb_correct, count = self.compute_error(
817             model,
818             "test",
819             nb_to_use=1000,
820             deterministic_synthesis=deterministic_synthesis,
821         )
822         logger(
823             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
824         )
825
826         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
827
828         if count is not None:
829             proportion_optimal = count.diagonal().sum().float() / count.sum()
830             logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
831             with open(
832                 os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
833             ) as f:
834                 for i in range(count.size(0)):
835                     for j in range(count.size(1)):
836                         eol = " " if j < count.size(1) - 1 else "\n"
837                         f.write(f"{count[i,j]}{eol}")
838
839         input = self.test_input[:48]
840         result = input.clone()
841         ar_mask = result.new_zeros(result.size())
842         ar_mask[:, self.height * self.width :] = 1
843         result *= 1 - ar_mask
844         masked_inplace_autoregression(
845             model,
846             self.batch_size,
847             result,
848             ar_mask,
849             deterministic_synthesis,
850             device=self.device,
851         )
852
853         mazes, paths = self.seq2map(input)
854         _, predicted_paths = self.seq2map(result)
855
856         filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png")
857         maze.save_image(
858             filename,
859             mazes=mazes,
860             target_paths=paths,
861             predicted_paths=predicted_paths,
862             path_correct=maze.path_correctness(mazes, predicted_paths),
863             path_optimal=maze.path_optimality(paths, predicted_paths),
864         )
865         logger(f"wrote {filename}")
866
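# A Maze sequence is the flattened maze (height * width tokens) followed by
# the flattened path (height * width tokens). compute_error() keeps the maze
# half, masks the path half, lets the model regenerate it, and checks the
# result with maze.path_correctness(); the count matrix tallies optimal vs.
# predicted path lengths over the correctly solved mazes.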
867
868 ######################################################################
869
870
871 import snake
872
873
874 class Snake(Task):
875     def __init__(
876         self,
877         nb_train_samples,
878         nb_test_samples,
879         batch_size,
880         height,
881         width,
882         nb_colors,
883         length,
884         prompt_length,
885         device=torch.device("cpu"),
886     ):
887         super().__init__()
888
889         self.batch_size = batch_size
890         self.height = height
891         self.width = width
892         self.device = device
893         self.prompt_length = prompt_length
894
895         self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
896             nb_train_samples,
897             height,
898             width,
899             nb_colors,
900             length,
901             prompt_length,
902             self.device,
903         )
904         self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
905             nb_test_samples,
906             height,
907             width,
908             nb_colors,
909             length,
910             prompt_length,
911             self.device,
912         )
913
914         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
915
916     def batches(self, split="train", nb_to_use=-1, desc=None):
917         assert split in {"train", "test"}
918         input = self.train_input if split == "train" else self.test_input
919         if nb_to_use > 0:
920             input = input[:nb_to_use]
921         if desc is None:
922             desc = f"epoch-{split}"
923         for batch in tqdm.tqdm(
924             input.split(self.batch_size), dynamic_ncols=True, desc=desc
925         ):
926             yield batch
927
928     def vocabulary_size(self):
929         return self.nb_codes
930
931     def produce_results(
932         self, n_epoch, model, result_dir, logger, deterministic_synthesis
933     ):
934         def compute_nb_correct(input, prior_visits):
935             result = input.clone()
936             i = torch.arange(result.size(1), device=result.device)[None, :]
937             ar_mask = (
938                 torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
939                 .long()
940                 .expand_as(result)
941             )
942             result *= 1 - ar_mask
943
944             masked_inplace_autoregression(
945                 model,
946                 self.batch_size,
947                 result,
948                 ar_mask,
949                 deterministic_synthesis,
950                 device=self.device,
951             )
952
953             nb_total = ((prior_visits > 0) * ar_mask).sum()
954
955             nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum()
956
957             return nb_total, nb_correct
958
959         test_nb_total, test_nb_correct = compute_nb_correct(
960             self.test_input[:1000], self.test_prior_visits[:1000]
961         )
962
963         logger(
964             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
965         )
966
967         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
968
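# In produce_results() above, only the even positions located after the
# prompt (the first 2 * prompt_length tokens) are regenerated, and the
# accuracy is measured only at positions whose cell was already visited
# (prior_visits > 0).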
969
970 ######################################################################
971
972
973 import stack
974
975
976 class Stack(Task):
977     def __init__(
978         self,
979         nb_train_samples,
980         nb_test_samples,
981         batch_size,
982         logger,
983         nb_steps,
984         nb_stacks,
985         nb_digits,
986         fraction_values_for_train=None,
987         device=torch.device("cpu"),
988     ):
989         super().__init__()
990
991         self.batch_size = batch_size
992         self.nb_steps = nb_steps
993         self.nb_stacks = nb_stacks
994         self.nb_digits = nb_digits
995         self.device = device
996
997         if fraction_values_for_train is None:
998             values_for_train = None
999             values_for_test = None
1000         else:
1001             all = torch.randperm(10**nb_digits)
1002             nb_for_train = int(all.size(0) * fraction_values_for_train)
1003             values_for_train = all[:nb_for_train]
1004             values_for_test = all[nb_for_train:]
1005
1006         self.train_input, self.train_stack_counts = stack.generate_sequences(
1007             nb_train_samples,
1008             nb_steps,
1009             nb_stacks,
1010             nb_digits,
1011             values_for_train,
1012             self.device,
1013         )
1014
1015         self.test_input, self.test_stack_counts = stack.generate_sequences(
1016             nb_test_samples,
1017             nb_steps,
1018             nb_stacks,
1019             nb_digits,
1020             values_for_test,
1021             self.device,
1022         )
1023
1024         i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
1025         counts = self.test_stack_counts.flatten()[i.flatten()]
1026         counts = F.one_hot(counts).sum(0)
1027         logger(f"test_pop_stack_counts {counts}")
1028
1029         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1030
1031     def batches(self, split="train", nb_to_use=-1, desc=None):
1032         assert split in {"train", "test"}
1033         input = self.train_input if split == "train" else self.test_input
1034         if nb_to_use > 0:
1035             input = input[:nb_to_use]
1036         if desc is None:
1037             desc = f"epoch-{split}"
1038         for batch in tqdm.tqdm(
1039             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1040         ):
1041             yield batch
1042
1043     def vocabulary_size(self):
1044         return self.nb_codes
1045
1046     def produce_results(
1047         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1048     ):
1049         def compute_nb_correct(input):
1050             result = input.clone()
1051             stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
1052             ar_mask = (result != input).long()
1053             masked_inplace_autoregression(
1054                 model,
1055                 self.batch_size,
1056                 result,
1057                 ar_mask,
1058                 deterministic_synthesis,
1059                 device=self.device,
1060             )
1061
1062             errors = ((result != input).long() * ar_mask).reshape(
1063                 -1, 1 + self.nb_digits
1064             )
1065             ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)
1066
1067             nb_total = ar_mask.max(1).values.sum()
1068             nb_correct = nb_total - errors.max(1).values.sum()
1069
1070             return nb_total, nb_correct
1071
1072         test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
1073
1074         logger(
1075             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1076         )
1077
1078         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
1079
1080         ##############################################################
1081         # Log a few generated sequences
1082         input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
1083         result = input.clone()
1084         stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
1085         ar_mask = (result != input).long()
1086
1087         # for n in range(result.size(0)):
1088         # logger(
1089         # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
1090         # )
1091
1092         masked_inplace_autoregression(
1093             model,
1094             self.batch_size,
1095             result,
1096             ar_mask,
1097             deterministic_synthesis,
1098             device=self.device,
1099         )
1100
1101         for n in range(result.size(0)):
1102             logger(
1103                 f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
1104             )
1105         ##############################################################
1106
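# For Stack, stack.remove_popped_values() blanks out the values returned by
# the pop operations, and the autoregression mask is exactly the set of
# positions that changed. Each operation spans 1 + nb_digits tokens, and a
# pop result is counted as correct only if all of its predicted digits match.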
1107
1108 ######################################################################
1109
1110 import rpl
1111
1112
1113 class RPL(Task):
1114     def tensorize(self, sequences):
1115         len_max = max([len(x) for x in sequences])
1116         return torch.cat(
1117             [
1118                 torch.tensor(
1119                     [
1120                         [
1121                             self.token2id[str(c)]
1122                             for c in s + ["<nul>"] * (len_max - len(s))
1123                         ]
1124                         for s in sequences
1125                     ]
1126                 )
1127             ],
1128             0,
1129         )
1130
1131     def seq2str(self, seq):
1132         return " ".join([self.id2token[i] for i in seq])
1133
1134     def __init__(
1135         self,
1136         nb_train_samples,
1137         nb_test_samples,
1138         batch_size,
1139         nb_starting_values=3,
1140         max_input=9,
1141         prog_len=6,
1142         nb_runs=5,
1143         no_prog=False,
1144         logger=None,
1145         device=torch.device("cpu"),
1146     ):
1147         super().__init__()
1148
1149         self.batch_size = batch_size
1150         self.device = device
1151         self.no_prog = no_prog
1152
1153         train_sequences = [
1154             rpl.generate(
1155                 nb_starting_values=nb_starting_values,
1156                 nb_result_values_max=4 * nb_starting_values,
1157                 max_input=max_input,
1158                 prog_len=prog_len,
1159                 nb_runs=nb_runs,
1160             )
1161             for _ in tqdm.tqdm(range(nb_train_samples), desc="train-data")
1162         ]
1163
1164         test_sequences = [
1165             rpl.generate(
1166                 nb_starting_values=nb_starting_values,
1167                 nb_result_values_max=4 * nb_starting_values,
1168                 max_input=max_input,
1169                 prog_len=prog_len,
1170                 nb_runs=nb_runs,
1171             )
1172             for _ in tqdm.tqdm(range(nb_test_samples), desc="test-data")
1173         ]
1174
1175         symbols = list(
1176             set(["<nul>"] + [x for l in train_sequences + test_sequences for x in l])
1177         )
1178         val_max = max([x if type(x) is int else 0 for x in symbols])
1179         symbols = list(filter(lambda x: type(x) is str, symbols))
1180         symbols.sort()
1181         symbols += [str(n) for n in range(val_max + 1)]
1182         self.token2id = dict([(c, n) for n, c in enumerate(symbols)])
1183         self.id2token = dict([(n, c) for c, n in self.token2id.items()])
1184
1185         self.t_nul = self.token2id["<nul>"]
1186         self.t_input = self.token2id["<in>"]
1187         self.t_output = self.token2id["<out>"]
1188         self.t_prog = self.token2id["<prg>"]
1189         self.t_end = self.token2id["<end>"]
1190
1191         self.train_input = self.tensorize(train_sequences)
1192         self.test_input = self.tensorize(test_sequences)
1193
1194         if no_prog:
1195             # Excise the program from every train and test example
1196             k = torch.arange(self.train_input.size(1), device=self.train_input.device)[
1197                 None, :
1198             ]
1199             p = (
1200                 ((self.train_input == self.t_prog).long() * k)
1201                 .max(1, keepdim=True)
1202                 .values
1203             )
1204             self.train_input = (
1205                 self.train_input * (k <= p).long()
1206                 + self.t_end * (k == p + 1).long()
1207                 + self.t_nul * (k > p + 1).long()
1208             )
1209             k = torch.arange(self.test_input.size(1), device=self.test_input.device)[
1210                 None, :
1211             ]
1212             p = (
1213                 ((self.test_input == self.t_prog).long() * k)
1214                 .max(1, keepdim=True)
1215                 .values
1216             )
1217             self.test_input = (
1218                 self.test_input * (k <= p).long()
1219                 + self.t_end * (k == p + 1).long()
1220                 + self.t_nul * (k > p + 1).long()
1221             )
1222
1223         if logger is not None:
1224             logger(f"value_max {val_max}")
1225             for x in self.train_input[:25]:
1226                 end = (x != self.t_nul).nonzero().max().item() + 1
1227                 seq = [self.id2token[i.item()] for i in x[:end]]
1228                 s = " ".join(seq)
1229                 logger(f"example_seq {s}")
1230
1231         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1232
1233     def batches(self, split="train", nb_to_use=-1, desc=None):
1234         assert split in {"train", "test"}
1235         input = self.train_input if split == "train" else self.test_input
1236         if nb_to_use > 0:
1237             input = input[:nb_to_use]
1238         if desc is None:
1239             desc = f"epoch-{split}"
1240         for batch in tqdm.tqdm(
1241             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1242         ):
1243             last = (batch != self.t_nul).max(0).values.nonzero().max() + 3
1244             batch = batch[:, :last].to(self.device)
1245             yield batch
1246
1247     def vocabulary_size(self):
1248         return self.nb_codes
1249
1250     def produce_results(
1251         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1252     ):
1253         # --------------------------------------------------------------------
1254         def compute_nb_errors_prog(input, nb_to_log=0):
1255             result = input.clone()
1256             s = (result == self.t_prog).long()
1257             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1258             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1259
1260             masked_inplace_autoregression(
1261                 model,
1262                 self.batch_size,
1263                 result,
1264                 ar_mask,
1265                 deterministic_synthesis,
1266                 device=self.device,
1267             )
1268
1269             sum_nb_total, sum_nb_errors = 0, 0
1270             for one_input, one_result in zip(input, result):
1271                 seq = [self.id2token[i.item()] for i in one_result]
1272                 nb_total, nb_errors, prog, stacks = rpl.compute_nb_errors(seq)
1273                 sum_nb_total += 1
1274                 sum_nb_errors += 0 if nb_errors == 0 else 1
1275                 if nb_to_log > 0:
1276                     gt_seq = [self.id2token[i.item()] for i in one_input]
1277                     _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq)
1278                     gt_prog = " ".join([str(x) for x in gt_prog])
1279                     prog = " ".join([str(x) for x in prog])
1280                     comment = "*" if nb_errors == 0 else "-"
1281                     logger(f"{comment} PROG [{gt_prog}] PREDICTED [{prog}]")
1282                     for start_stack, target_stack, result_stack, correct in stacks:
1283                         comment = "*" if correct else "-"
1284                         start_stack = " ".join([str(x) for x in start_stack])
1285                         target_stack = " ".join([str(x) for x in target_stack])
1286                         result_stack = " ".join([str(x) for x in result_stack])
1287                         logger(
1288                             f"  {comment} [{start_stack}] -> [{target_stack}] PREDICTED [{result_stack}]"
1289                         )
1290                     nb_to_log -= 1
1291
1292             return sum_nb_total, sum_nb_errors
1293
1294         # --------------------------------------------------------------------
1295         def compute_nb_errors_output(input, nb_to_log=0):
1296             result = input.clone()
1297             k = torch.arange(result.size(1), device=result.device)[None, :]
1298             last_output_idx = (
1299                 ((result == self.t_output) * k).max(dim=1, keepdim=True).values
1300             )
1301             first_prog_idx = (
1302                 ((result == self.t_prog) * k).max(dim=1, keepdim=True).values
1303             )
1304             ar_mask = (k > last_output_idx).long() * (k < first_prog_idx).long()
1305             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1306
1307             masked_inplace_autoregression(
1308                 model,
1309                 self.batch_size,
1310                 result,
1311                 ar_mask,
1312                 deterministic_synthesis,
1313                 device=self.device,
1314             )
1315
1316             sum_nb_total, sum_nb_errors = 0, 0
1317             for one_input, one_result, i, j in zip(
1318                 input, result, last_output_idx, first_prog_idx
1319             ):
1320                 seq = [self.id2token[i.item()] for i in one_result]
1321                 sum_nb_total += 1
1322                 correct = (one_input - one_result).abs().max() == 0
1323                 sum_nb_errors += 0 if correct else 1
1324                 if nb_to_log > 0:
1325                     result_stack = [
1326                         self.id2token[i.item()] for i in one_result[i : j + 1]
1327                     ]
1328                     target_stack = [
1329                         self.id2token[i.item()] for i in one_input[i : j + 1]
1330                     ]
1331                     comment = "*" if correct else "-"
1332                     result_stack = " ".join([str(x) for x in result_stack])
1333                     target_stack = " ".join([str(x) for x in target_stack])
1334                     logger(
1335                         f"output_test {comment} [{target_stack}] PREDICTED [{result_stack}]"
1336                     )
1337                     nb_to_log -= 1
1338
1339             return sum_nb_total, sum_nb_errors
1340
1341         # --------------------------------------------------------------------
1342
1343         if not self.no_prog:
1344             test_nb_total, test_nb_errors = compute_nb_errors_prog(
1345                 self.test_input[:1000].to(self.device), nb_to_log=10
1346             )
1347
1348             logger(
1349                 f"accuracy_prog_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1350             )
1351
1352             logger(f"main_test_accuracy {n_epoch} {1-test_nb_errors/test_nb_total}")
1353
1354         test_nb_total, test_nb_errors = compute_nb_errors_output(
1355             self.test_input[:1000].to(self.device), nb_to_log=10
1356         )
1357
1358         logger(
1359             f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1360         )
1361
1362         if save_attention_image is None:
1363             logger("no save_attention_image (is pycairo installed?)")
1364         else:
1365             ns = torch.randint(self.test_input.size(0), (1,)).item()
1366             input = self.test_input[ns : ns + 1].clone()
1367             last = (input != self.t_nul).max(0).values.nonzero().max() + 3
1368             input = input[:, :last].to(self.device)
1369
1370             with torch.autograd.no_grad():
1371                 t = model.training
1372                 model.eval()
1373                 model.record_attention(True)
1374                 model(BracketedSequence(input))
1375                 model.train(t)
1376                 ram = model.retrieve_attention()
1377                 model.record_attention(False)
1378
1379             tokens_output = [self.id2token[i.item()] for i in input[0]]
1380             tokens_input = ["n/a"] + tokens_output[:-1]
1381             for n_head in range(ram[0].size(1)):
1382                 filename = os.path.join(
1383                     result_dir, f"rpl_attention_{n_epoch}_h{n_head}.pdf"
1384                 )
1385                 attention_matrices = [m[0, n_head] for m in ram]
1386                 save_attention_image(
1387                     filename,
1388                     tokens_input,
1389                     tokens_output,
1390                     attention_matrices,
1391                     k_top=10,
1392                     # min_total_attention=0.9,
1393                     token_gap=12,
1394                     layer_gap=50,
1395                 )
1396                 logger(f"wrote {filename}")
1397
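# An RPL sequence is a series of <in> ... <out> ... runs followed by
# <prg> ... <end>. compute_nb_errors_prog() masks everything after <prg>, so
# the model has to synthesize the program from the input/output examples,
# while compute_nb_errors_output() masks the tokens between the last <out>
# and <prg>, so the model has to predict the result of the last run. With
# no_prog set, the programs are excised from the data and only the output
# evaluation is performed.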
1398
1399 ######################################################################
1400
1401
1402 import expr
1403
1404
1405 class Expr(Task):
1406     def tensorize(self, sequences):
1407         len_max = max([len(x) for x in sequences])
1408         return torch.cat(
1409             [
1410                 torch.tensor(
1411                     [
1412                         [self.char2id[c] for c in s + "#" * (len_max - len(s))]
1413                         for s in sequences
1414                     ]
1415                 )
1416             ],
1417             0,
1418         ).to(self.device)
1419
1420     def __init__(
1421         self,
1422         nb_train_samples,
1423         nb_test_samples,
1424         nb_variables,
1425         sequence_length,
1426         operand_max,
1427         result_max,
1428         batch_size,
1429         device=torch.device("cpu"),
1430     ):
1431         super().__init__()
1432
1433         self.batch_size = batch_size
1434         self.device = device
1435
1436         train_sequences = expr.generate_sequences(
1437             nb_train_samples,
1438             nb_variables=nb_variables,
1439             length=sequence_length,
1440             operand_max=operand_max,
1441             result_max=result_max,
1442         )
1443
1444         test_sequences = expr.generate_sequences(
1445             nb_test_samples,
1446             nb_variables=nb_variables,
1447             length=sequence_length,
1448             operand_max=operand_max,
1449             result_max=result_max,
1450         )
1451
1452         symbols = list(set("#" + "".join(train_sequences + test_sequences)))
1453         symbols.sort()
1454
1455         self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
1456         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
1457
1458         self.filler, self.space = self.char2id["#"], self.char2id[" "]
1459
1460         self.train_input = self.tensorize(train_sequences)
1461         self.test_input = self.tensorize(test_sequences)
1462
1463         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1464
1465     def batches(self, split="train", nb_to_use=-1, desc=None):
1466         assert split in {"train", "test"}
1467         input = self.train_input if split == "train" else self.test_input
1468         if nb_to_use > 0:
1469             input = input[:nb_to_use]
1470         if desc is None:
1471             desc = f"epoch-{split}"
1472         for batch in tqdm.tqdm(
1473             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1474         ):
1475             last = (batch != self.filler).max(0).values.nonzero().max() + 3
1476             batch = batch[:, :last]
1477             yield batch
1478
1479     def vocabulary_size(self):
1480         return self.nb_codes
1481
1482     def seq2str(self, s):
1483         return "".join([self.id2char[k.item()] for k in s])
1484
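    # In produce_results() below, everything after the first space of each
    # sequence is masked and regenerated, expr.extract_results() recovers the
    # variable values from the true and predicted strings, and the absolute
    # differences are accumulated in nb_delta (differences of 5 or more, or
    # missing values, are counted in nb_missed).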
1485     def produce_results(
1486         self,
1487         n_epoch,
1488         model,
1489         result_dir,
1490         logger,
1491         deterministic_synthesis,
1492         input_file=None,
1493     ):
1494         def compute_nb_correct(input):
1495             result = input.clone()
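            # ar_mask is 1 strictly after the first space token: the prompt,
            # up to and including that space, is kept, while everything after
            # it is replaced by '#' and regenerated by the model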
1496             s = (result == self.space).long()
1497             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1498             result = (1 - ar_mask) * result + ar_mask * self.filler
1499             masked_inplace_autoregression(
1500                 model,
1501                 self.batch_size,
1502                 result,
1503                 ar_mask,
1504                 deterministic_synthesis,
1505                 device=self.device,
1506             )
1507
1508             nb_total = input.size(0)
1509             nb_correct = (input == result).long().min(1).values.sum()
1510
1511             #######################################################################
1512             # Compute predicted vs. true variable values
1513
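            # nb_delta[d] counts predicted values that are off by exactly d
            # (for d < 5); larger errors and missing or negative predictions
            # are counted in nb_missed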
1514             nb_delta = torch.zeros(5, dtype=torch.int64)
1515             nb_missed = 0
1516
1517             values_input = expr.extract_results([self.seq2str(s) for s in input])
1518             values_result = expr.extract_results([self.seq2str(s) for s in result])
1519
1520             filename = os.path.join(result_dir, f"expr_result_{n_epoch:04d}.txt")
1521
1522             with open(filename, "w") as f:
1523                 for i, r in zip(values_input, values_result):
1524                     for n, vi in i.items():
1525                         vr = r.get(n)
1526                         f.write(f"{vi} {-1 if vr is None else vr}\n")
1527
1528                         if vr is None or vr < 0:
1529                             nb_missed += 1
1530                         else:
1531                             d = abs(vr - vi)
1532                             if d >= nb_delta.size(0):
1533                                 nb_missed += 1
1534                             else:
1535                                 nb_delta[d] += 1
1536
1537             ######################################################################
1538
1539             return nb_total, nb_correct, nb_delta, nb_missed
1540
1541         (
1542             test_nb_total,
1543             test_nb_correct,
1544             test_nb_delta,
1545             test_nb_missed,
1546         ) = compute_nb_correct(self.test_input[:10000])
1547
1548         logger(
1549             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1550         )
1551
1552         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
1553
1554         nb_total = test_nb_delta.sum() + test_nb_missed
1555         for d in range(test_nb_delta.size(0)):
1556             logger(
1557                 f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%"
1558             )
1559         logger(
1560             f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%"
1561         )
1562
1563         ##############################################################
1564         # Log a few generated sequences
1565         if input_file is None:
1566             input = self.test_input[:10]
1567         else:
1568             with open(input_file, "r") as f:
1569                 sequences = [e.strip() for e in f.readlines()]
1570                 sequences = [s + " " + "#" * 50 for s in sequences]
1571                 input = self.tensorize(sequences)
1572
1573         result = input.clone()
1574         s = (result == self.space).long()
1575         ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1576         result = (1 - ar_mask) * result + ar_mask * self.filler
1577
1578         for n in range(result.size(0)):
1579             logger(f"test_before {self.seq2str(result[n])}")
1580
1581         masked_inplace_autoregression(
1582             model,
1583             self.batch_size,
1584             result,
1585             ar_mask,
1586             deterministic_synthesis,
1587             device=self.device,
1588         )
1589
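        # Ground truth shown only on the generated part: the prompt positions
        # are blanked out with spaces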
1590         correct = (1 - ar_mask) * self.space + ar_mask * input
1591         for n in range(result.size(0)):
1592             comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
1593             logger(f"test_after  {self.seq2str(result[n])} {comment}")
1594             logger(f"truth       {self.seq2str(correct[n])}")
1595         ##############################################################
1596
1597
1598 ######################################################################
1599
1600 import grid
1601
1602
1603 class Grid(Task):
1604     # Make a tensor from a list of strings
1605     def str2tensor(self, descr):
1606         token_descr = [s.strip().split(" ") for s in descr]
1607         len_max = max([len(s) for s in token_descr])
1608         token_descr = [s + ["#"] * (len_max - len(s)) for s in token_descr]
1609         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
1610         return torch.tensor(id_descr, device=self.device)
1611
1612     # Make a list of strings from a tensor
1613     def tensor2str(self, x):
1614         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
1615
1616     # Trim the tensor to remove as many padding tokens as possible on the
1617     # left and on the right. If z is a tuple, all its elements are trimmed
1618     # according to the trimming computed for the first one.
1619     def trim(self, z, token="#"):
1620         n = self.token2id[token]
1621         if type(z) == tuple:
1622             x = z[0]
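            # Pad one '#' column on each side, mark the columns that contain
            # at least one non-padding token, and keep the slice running from
            # the first to the last such column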
1623             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
1624             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
1625             return tuple([t[:, a:b] for t in z])
1626         else:
1627             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
1628             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
1629             return z[:, a:b]
1630
1631     ######################
1632
1633     def __init__(
1634         self,
1635         nb_train_samples,
1636         nb_test_samples,
1637         batch_size,
1638         size,
1639         fraction_play=0.0,
1640         logger=None,
1641         device=torch.device("cpu"),
1642     ):
1643         super().__init__()
1644
1645         self.device = device
1646         self.batch_size = batch_size
1647         self.grid_factory = grid.GridFactory(size=size)
1648         self.fraction_play = fraction_play
1649
1650         if logger is not None:
1651             logger(
1652                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
1653             )
1654
1655         self.train_descr = self.grid_factory.generate_samples(
1656             nb=nb_train_samples,
1657             fraction_play=fraction_play,
1658             progress_bar=lambda r: tqdm.tqdm(r),
1659         )
1660
1661         self.test_descr = self.grid_factory.generate_samples(
1662             nb=nb_test_samples, fraction_play=0.0, progress_bar=lambda r: tqdm.tqdm(r)
1663         )
1664
1665         if fraction_play > 0:
1666             self.play_descr = self.grid_factory.generate_samples(
1667                 nb=25, fraction_play=1.0, progress_bar=lambda r: tqdm.tqdm(r)
1668             )
1669         else:
1670             self.play_descr = []
1671
1672         # Build the tokenizer
1673         tokens = set()
1674         for d in [self.train_descr, self.test_descr, self.play_descr]:
1675             for s in d:
1676                 for t in s.strip().split(" "):
1677                     tokens.add(t)
1678         # make this set a sorted list to get the same tensors given
1679         # the same descr
1680         tokens = list(tokens)
1681         tokens.sort()
1682         tokens = ["#"] + tokens
1683         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
1684         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
1685         self.t_nul = self.token2id["#"]
1686         self.t_true = self.token2id["true"]
1687         self.t_false = self.token2id["false"]
1688         self.t_pipe = self.token2id["|"] if "|" in self.token2id else None  # used on play sequences in produce_results
1689
1690         # Tokenize the train and test sets
1691         self.train_input = self.str2tensor(self.train_descr)
1692         self.test_input = self.str2tensor(self.test_descr)
1693         self.play_input = (
1694             None if len(self.play_descr) == 0 else self.str2tensor(self.play_descr)
1695         )
1696
1697     def batches(self, split="train", nb_to_use=-1, desc=None):
1698         assert split in {"train", "test"}
1699         input = self.train_input if split == "train" else self.test_input
1700         for batch in tqdm.tqdm(
1701             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
1702         ):
1703             yield self.trim(batch)
1704
1705     def vocabulary_size(self):
1706         return len(self.token2id)
1707
1708     def produce_results(
1709         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1710     ):
1711         correct = self.test_input[:1000]
1712         result = correct.clone()
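        # Only the "true"/"false" answer tokens are masked out and regenerated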
1713         ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long()
1714         result *= 1 - ar_mask  # paraaaaanoiaaaaaaa
1715
1716         logger(f"----------------------------------------------------------")
1717
1718         for e in self.tensor2str(result[:10]):
1719             logger(f"test_before {e}")
1720
1721         masked_inplace_autoregression(
1722             model,
1723             self.batch_size,
1724             result,
1725             ar_mask,
1726             deterministic_synthesis,
1727             device=self.device,
1728         )
1729
1730         logger(f"----------------------------------------------------------")
1731
1732         for e in self.tensor2str(result[:10]):
1733             logger(f"test_after  {e}")
1734
1735         logger(f"----------------------------------------------------------")
1736
1737         nb_total = ar_mask.sum().item()
1738         nb_correct = ((correct == result).long() * ar_mask).sum().item()
1739
1740         logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
1741         logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
1742
1743         if self.play_input is not None:
1744             result = self.play_input.clone()
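            # Everything from the first '|' separator onward is regenerated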
1745             ar_mask = (result == self.t_pipe).long().cumsum(dim=1).clamp(max=1)
1746             result *= 1 - ar_mask  # paraaaaanoiaaaaaaa
1747
1748             logger(f"----------------------------------------------------------")
1749
1750             for e in self.tensor2str(result[:10]):
1751                 logger(f"play_before {e}")
1752
1753             masked_inplace_autoregression(
1754                 model,
1755                 self.batch_size,
1756                 result,
1757                 ar_mask,
1758                 deterministic_synthesis,
1759                 device=self.device,
1760             )
1761
1762             logger(f"----------------------------------------------------------")
1763
1764             for e in self.tensor2str(result[:10]):
1765                 logger(f"play_after  {e}")
1766
1767             logger(f"----------------------------------------------------------")
1768
1769
1770 ######################################################################
1771
1772 import qmlp
1773
1774
1775 class QMLP(Task):
1776     ######################
1777
1778     def __init__(
1779         self,
1780         nb_train_samples,
1781         nb_test_samples,
1782         batch_size,
1783         result_dir,
1784         logger=None,
1785         device=torch.device("cpu"),
1786     ):
1787         super().__init__()
1788
1789         self.device = device
1790         self.batch_size = batch_size
1791         self.nb_samples_per_mlp = 256
1792
1793         if logger is not None:
1794             logger(
1795                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
1796             )
1797
1798         seq, q_test_set, test_error = qmlp.generate_sequence_and_test_set(
1799             nb_mlps=nb_train_samples + nb_test_samples,
1800             nb_samples=self.nb_samples_per_mlp,
1801             device=self.device,
1802             batch_size=64,
1803             nb_epochs=250,
1804             nb_mlps_per_batch=1024,
1805         )
1806
1807         self.train_input = seq[:nb_train_samples]
1808         self.train_q_test_set = q_test_set[:nb_train_samples]
1809         self.train_ref_test_errors = test_error[:nb_train_samples]
1810         self.test_input = seq[nb_train_samples:]
1811         self.test_q_test_set = q_test_set[nb_train_samples:]
1812         self.test_ref_test_errors = test_error[nb_train_samples:]
1813
1814         filename = os.path.join(result_dir, "train_errors_ref.dat")
1815         with open(filename, "w") as f:
1816             for e in self.train_ref_test_errors:
1817                 f.write(f"{e}\n")
1818
1819         filename = os.path.join(result_dir, "test_errors_ref.dat")
1820         with open(filename, "w") as f:
1821             for e in self.test_ref_test_errors:
1822                 f.write(f"{e}\n")
1823
1824         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1825
1826     def batches(self, split="train", nb_to_use=-1, desc=None):
1827         assert split in {"train", "test"}
1828         input = self.train_input if split == "train" else self.test_input
1829         for batch in tqdm.tqdm(
1830             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
1831         ):
1832             yield batch
1833
1834     def vocabulary_size(self):
1835         return self.nb_codes
1836
1837     def produce_results(
1838         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1839     ):
1840         correct = self.test_input[:1000]
1841         result = correct.clone()
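        # Only the positions after index nb_samples_per_mlp * 3 + 1 are
        # regenerated: the quantized training set (3 tokens per sample) and
        # the tokens right after it are kept as the prompt, the rest encodes
        # the MLP parameters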
1842         ar_mask = (
1843             torch.arange(result.size(1), device=result.device)
1844             > self.nb_samples_per_mlp * 3 + 1
1845         ).long()[None, :]
1846         ar_mask = ar_mask.expand_as(result)
1847         result *= 1 - ar_mask  # paraaaaanoiaaaaaaa
1848
1849         masked_inplace_autoregression(
1850             model,
1851             self.batch_size,
1852             result,
1853             ar_mask,
1854             deterministic_synthesis,
1855             device=self.device,
1856         )
1857
1858         q_train_set = result[:, : self.nb_samples_per_mlp * 3]
1859         q_params = result[:, self.nb_samples_per_mlp * 3 + 1 :]
1860         error_test = qmlp.evaluate_q_params(q_params, self.test_q_test_set)
1861
1862         filename = os.path.join(result_dir, f"test_errors_{n_epoch:04d}.dat")
1863         with open(filename, "w") as f:
1864             for e in error_test:
1865                 f.write(f"{e}\n")
1866
1867
1868 ######################################################################
1869
1870 import greed
1871
1872
1873 class Greed(Task):
1874     def __init__(
1875         self,
1876         nb_train_samples,
1877         nb_test_samples,
1878         batch_size,
1879         height,
1880         width,
1881         T,
1882         nb_walls,
1883         nb_coins,
1884         logger=None,
1885         device=torch.device("cpu"),
1886     ):
1887         super().__init__()
1888
1889         self.batch_size = batch_size
1890         self.device = device
1891
1892         self.world = greed.GreedWorld(height, width, T, nb_walls, nb_coins)
1893
1894         states, actions, rewards = self.world.generate_episodes(
1895             nb_train_samples + nb_test_samples
1896         )
1897         seq = self.world.episodes2seq(states, actions, rewards)
1898         self.train_input = seq[:nb_train_samples].to(self.device)
1899         self.test_input = seq[nb_train_samples:].to(self.device)
1900
1901     def wipe_lookahead_rewards(self, batch):
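        # For each sequence, draw a random position u and overwrite every
        # lookahead_reward token located at or before u with the
        # REWARD_UNKNOWN code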
1902         t = torch.arange(batch.size(1), device=batch.device)[None, :]
1903         u = torch.randint(batch.size(1), (batch.size(0), 1), device=batch.device)
1904         lr_mask = (t <= u).long() * (
1905             t % self.world.it_len == self.world.index_lookahead_reward
1906         ).long()
1907
1908         return (
1909             lr_mask * self.world.lookahead_reward2code(greed.REWARD_UNKNOWN)
1910             + (1 - lr_mask) * batch
1911         )
1912
1913     def batches(self, split="train", nb_to_use=-1, desc=None):
1914         assert split in {"train", "test"}
1915         input = self.train_input if split == "train" else self.test_input
1916         if nb_to_use > 0:
1917             input = input[:nb_to_use]
1918         if desc is None:
1919             desc = f"epoch-{split}"
1920         for batch in tqdm.tqdm(
1921             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1922         ):
1923             yield self.wipe_lookahead_rewards(batch)
1924
1925     def vocabulary_size(self):
1926         return self.world.nb_codes
1927
1928     def thinking_autoregression(
1929         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
1930     ):
1931         snapshots = []
1932
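        # Helper: erase the masked positions of result, regenerate them in
        # place with the model, and keep a snapshot of the first 100 sequences
        # for the "thinking" trace written below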
1933         def ar(result, ar_mask, logit_biases=None):
1934             ar_mask = ar_mask.expand_as(result)
1935             result *= 1 - ar_mask
1936             masked_inplace_autoregression(
1937                 model,
1938                 self.batch_size,
1939                 result,
1940                 ar_mask,
1941                 deterministic_synthesis=deterministic_synthesis,
1942                 logit_biases=logit_biases,
1943                 device=self.device,
1944                 progress_bar_desc=None,
1945             )
1946             warnings.warn("keeping thinking snapshots", RuntimeWarning)
1947             snapshots.append(result[:100].detach().clone())
1948
1949         # Generate iteration after iteration
1950
1951         result = self.test_input[:250].clone()
1952         # Erase all the content but that of the first iteration
1953         result[:, self.world.it_len :] = -1
1954         # Set the lookahead_reward of the first iteration to UNKNOWN
1955         result[:, self.world.index_lookahead_reward] = self.world.lookahead_reward2code(
1956             greed.REWARD_UNKNOWN
1957         )
1958
1959         t = torch.arange(result.size(1), device=result.device)[None, :]
1960
1961         for u in tqdm.tqdm(
1962             range(0, result.size(1), self.world.it_len),
1963             desc="thinking",
1964         ):
1965             # Generate the next state but keep the initial one; the
1966             # lookahead_rewards of the previous iterations have been set
1967             # to UNKNOWN
1968             if u > 0:
1969                 result[
1970                     :, u + self.world.index_lookahead_reward
1971                 ] = self.world.lookahead_reward2code(greed.REWARD_UNKNOWN)
1972                 ar_mask = (t >= u + self.world.index_states).long() * (
1973                     t < u + self.world.index_states + self.world.state_len
1974                 ).long()
1975                 ar(result, ar_mask)
1976
1977             # Generate the action and reward with lookahead_reward to +1
1978             result[
1979                 :, u + self.world.index_lookahead_reward
1980             ] = self.world.lookahead_reward2code(greed.REWARD_PLUS)
1981             ar_mask = (t >= u + self.world.index_reward).long() * (
1982                 t <= u + self.world.index_action
1983             ).long()
1984             ar(result, ar_mask)
1985
1986             # Set the lookahead_reward to UNKNOWN for the next iterations
1987             result[
1988                 :, u + self.world.index_lookahead_reward
1989             ] = self.world.lookahead_reward2code(greed.REWARD_UNKNOWN)
1990
1991         filename = os.path.join(result_dir, f"test_thinking_compute_{n_epoch:04d}.txt")
1992         with open(filename, "w") as f:
1993             for n in range(snapshots[0].size(0)):
1994                 for snapshot in snapshots:
1995                     lr, s, a, r = self.world.seq2episodes(
1996                         snapshot[n : n + 1],
1997                     )
1998                     seq_str = self.world.episodes2str(
1999                         lr, s, a, r, unicode=True, ansi_colors=True
2000                     )
2001                     f.write(seq_str)
2002                 f.write("\n\n")
2003
2004         # Saving the generated sequences
2005
2006         lr, s, a, r = self.world.seq2episodes(result)
2007         seq_str = self.world.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
2008
2009         filename = os.path.join(result_dir, f"test_thinking_seq_{n_epoch:04d}.txt")
2010         with open(filename, "w") as f:
2011             f.write(seq_str)
2012             logger(f"wrote {filename}")
2013
2014     def produce_results(
2015         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
2016     ):
2017         result = self.wipe_lookahead_rewards(self.test_input[:250].clone())
2018
2019         # Saving the ground truth
2020
2021         lr, s, a, r = self.world.seq2episodes(
2022             result,
2023         )
2024         seq_str = self.world.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
2025
2026         filename = os.path.join(result_dir, f"test_true_seq_{n_epoch:04d}.txt")
2027         with open(filename, "w") as f:
2028             f.write(seq_str)
2029             logger(f"wrote {filename}")
2030
2031         # Re-generating from the first frame
2032
2033         ar_mask = (
2034             torch.arange(result.size(1), device=result.device) >= self.world.it_len
2035         ).long()[None, :]
2036         ar_mask = ar_mask.expand_as(result)
2037         result *= 1 - ar_mask  # paraaaaanoiaaaaaaa
2038
2039         masked_inplace_autoregression(
2040             model,
2041             self.batch_size,
2042             result,
2043             ar_mask,
2044             deterministic_synthesis,
2045             device=self.device,
2046         )
2047
2048         # Saving the generated sequences
2049
2050         lr, s, a, r = self.world.seq2episodes(
2051             result,
2052         )
2053         seq_str = self.world.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
2054
2055         filename = os.path.join(result_dir, f"test_seq_{n_epoch:04d}.txt")
2056         with open(filename, "w") as f:
2057             f.write(seq_str)
2058             logger(f"wrote {filename}")
2059
2060         self.thinking_autoregression(
2061             n_epoch, model, result_dir, logger, deterministic_synthesis, nmax
2062         )
2063
2064
2065 ######################################################################