1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, os, tqdm
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 from mygpt import BracketedSequence
16
17 # from graph import save_attention_image
18 save_attention_image = None
19
20 ######################################################################
21
22
23 def masked_inplace_autoregression(
24     model,
25     batch_size,
26     input,
27     ar_mask,
28     deterministic_synthesis,
29     forbidden_tokens=None,
30     progress_bar_desc="autoregression",
31     device=torch.device("cpu"),
32 ):
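    # Wrapper that runs model.masked_inplace_autoregression() minibatch by
    # minibatch, with gradients disabled and the model temporarily switched
    # to eval mode. ar_mask is 1 at the positions to generate and 0 at the
    # positions kept from the input.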
33     assert input.size() == ar_mask.size()
34
35     batches = zip(input.split(batch_size), ar_mask.split(batch_size))
36
37     if progress_bar_desc is not None:
38         batches = tqdm.tqdm(
39             batches,
40             dynamic_ncols=True,
41             desc=progress_bar_desc,
42             total=(input.size(0) + batch_size - 1) // batch_size,
43         )
44
45     with torch.autograd.no_grad():
46         t = model.training
47         model.eval()
48
49         for input, ar_mask in batches:
50             model.masked_inplace_autoregression(
51                 input, ar_mask, forbidden_tokens, deterministic_synthesis
52             )
53
54         model.train(t)
55
56
57 ######################################################################
58
59
60 class Task:
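    # Abstract interface shared by all the tasks below: batches() yields the
    # minibatches of a split, vocabulary_size() returns the number of
    # distinct tokens, and produce_results() evaluates a model and logs the
    # outcome.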
61     def batches(self, split="train"):
62         pass
63
64     def vocabulary_size(self):
65         pass
66
67     def produce_results(
68         self, n_epoch, model, result_dir, logger, deterministic_synthesis
69     ):
70         pass
71
72
73 ####################
74
75 import problems
76
77
78 class SandBox(Task):
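    # Generic task wrapping a "problem" object (see problems.py) that
    # provides generate_sequences(), seq2str() and compute_nb_correct().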
79     def __init__(
80         self,
81         problem,
82         nb_train_samples,
83         nb_test_samples,
84         batch_size,
85         logger=None,
86         device=torch.device("cpu"),
87         max_nb_codes=1024,
88     ):
89         super().__init__()
90
91         self.batch_size = batch_size
92         self.device = device
93         self.problem = problem
94
95         self.train_input, self.train_ar_mask = self.problem.generate_sequences(
96             nb_train_samples
97         )
98         self.test_input, self.test_ar_mask = self.problem.generate_sequences(
99             nb_test_samples
100         )
101
102         self.train_input, self.train_ar_mask = self.train_input.to(
103             device
104         ), self.train_ar_mask.to(device)
105         self.test_input, self.test_ar_mask = self.test_input.to(
106             device
107         ), self.test_ar_mask.to(device)
108
109         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
110
111         # A bit of paranoia never hurts
112         assert self.nb_codes <= max_nb_codes
113         assert self.train_input.min() >= 0
114         assert self.test_input.min() >= 0
115         assert tuple(x.item() for x in self.train_ar_mask.unique()) in {
116             (0,),
117             (1,),
118             (0, 1),
119         }
120         assert tuple(x.item() for x in self.test_ar_mask.unique()) in {
121             (0,),
122             (1,),
123             (0, 1),
124         }
125
126         if logger is not None:
127             for s, a in zip(self.train_input[:100], self.train_ar_mask[:100]):
128                 logger(f"train_sequences {self.problem.seq2str(s)}")
129                 a = "".join(["01"[x.item()] for x in a])
130                 logger(f"                {a}")
131
132     def batches(self, split="train", nb_to_use=-1, desc=None):
133         assert split in {"train", "test"}
134         input = self.train_input if split == "train" else self.test_input
135         if nb_to_use > 0:
136             input = input[:nb_to_use]
137         if desc is None:
138             desc = f"epoch-{split}"
139         for batch in tqdm.tqdm(
140             input.split(self.batch_size), dynamic_ncols=True, desc=desc
141         ):
142             yield batch
143
144     def vocabulary_size(self):
145         return self.nb_codes
146
147     def produce_results(
148         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
149     ):
150         def compute_accuracy(input, ar_mask, logger=None):
151             input, ar_mask = input[:nmax], ar_mask[:nmax]
152             result = input.clone() * (1 - ar_mask)
153
154             masked_inplace_autoregression(
155                 model,
156                 self.batch_size,
157                 result,
158                 ar_mask,
159                 deterministic_synthesis,
160                 progress_bar_desc=None,
161                 device=self.device,
162             )
163
164             log_ground_truth = ar_mask.min() == 0
165
166             if logger is not None:
167                 for sp, st in zip(result[:10], input[:10]):
168                     logger(
169                         f"test_sequences {n_epoch} prediction   {self.problem.seq2str(sp)}"
170                     )
171                     if log_ground_truth:
172                         logger(
173                             f"               {n_epoch} ground truth {self.problem.seq2str(st)}"
174                         )
175
176             nb_total, nb_correct = self.problem.compute_nb_correct(
177                 input, ar_mask, result
178             )
179
180             # nb_total = ar_mask.sum().item()
181             # nb_correct = ((result == input).long() * ar_mask).sum().item()
182
183             return nb_total, nb_correct
184
185         train_nb_total, train_nb_correct = compute_accuracy(
186             self.train_input, self.train_ar_mask
187         )
188
189         logger(
190             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
191         )
192
193         test_nb_total, test_nb_correct = compute_accuracy(
194             self.test_input, self.test_ar_mask, logger
195         )
196
197         logger(
198             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
199         )
200
201         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
202
203         if save_attention_image is not None:
204             for k in range(10):
205                 ns = torch.randint(self.test_input.size(0), (1,)).item()
206                 input = self.test_input[ns : ns + 1].clone()
207
208                 with torch.autograd.no_grad():
209                     t = model.training
210                     model.eval()
211                     # model.record_attention(True)
212                     model(BracketedSequence(input))
213                     model.train(t)
214                     # ram = model.retrieve_attention()
215                     # model.record_attention(False)
216
217                 # tokens_output = [c for c in self.problem.seq2str(input[0])]
218                 # tokens_input = ["n/a"] + tokens_output[:-1]
219                 # for n_head in range(ram[0].size(1)):
220                 # filename = os.path.join(
221                 # result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
222                 # )
223                 # attention_matrices = [m[0, n_head] for m in ram]
224                 # save_attention_image(
225                 # filename,
226                 # tokens_input,
227                 # tokens_output,
228                 # attention_matrices,
229                 # k_top=10,
230                 ##min_total_attention=0.9,
231                 # token_gap=12,
232                 # layer_gap=50,
233                 # )
234                 # logger(f"wrote {filename}")
235
236
237 ######################################################################
238
239 import picoclvr
240
241
242 class PicoCLVR(Task):
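    # Task on picoclvr scene descriptions: each sample is a textual
    # description of the scene followed by the <img> token and the image
    # encoded as a sequence of grid tokens, with <nul> used for padding.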
243     # Make a tensor from a list of strings
244     def tensorize(self, descr):
245         token_descr = [s.strip().split(" ") for s in descr]
246         l = max([len(s) for s in token_descr])
247         token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
248         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
249         return torch.tensor(id_descr, device=self.device)
250
251     # Make a list of strings from a tensor
252     def detensorize(self, x):
253         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
254
255     # Trim z to remove as many padding tokens as possible from the left
256     # and right of the first tensor. If z is a tuple, all its elements
257     # are trimmed according to the trimming of the first.
258     def trim(self, z, token="<nul>"):
259         n = self.token2id[token]
260         if type(z) == tuple:
261             x = z[0]
262             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
263             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
264             return tuple([t[:, a:b] for t in z])
265         else:
266             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
267             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
268             return z[:, a:b]
269
270     ######################
271
272     def __init__(
273         self,
274         nb_train_samples,
275         nb_test_samples,
276         batch_size,
277         height,
278         width,
279         nb_colors=5,
280         logger=None,
281         device=torch.device("cpu"),
282         pruner_train=None,
283         pruner_eval=None,
284     ):
285         super().__init__()
286
287         def generate_descr(nb, cache_suffix, pruner):
288             return picoclvr.generate(
289                 nb,
290                 height=self.height,
291                 width=self.width,
292                 nb_colors=nb_colors,
293                 pruner=pruner,
294             )
295
296         self.height = height
297         self.width = width
298         self.batch_size = batch_size
299         self.device = device
300         self.pruner_train = pruner_train
301         self.pruner_eval = pruner_eval
302
303         if logger is not None:
304             logger(
305                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
306             )
307
308         self.train_descr = generate_descr(
309             nb_train_samples, "train", pruner=self.pruner_train
310         )
311         self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)
312
313         # Build the tokenizer
314         tokens = {"<nul>", "<img>"}
315         for d in [self.train_descr, self.test_descr]:
316             for s in d:
317                 for t in s.strip().split(" "):
318                     tokens.add(t)
319         # turn the set into a sorted list so that the same descriptions
320         # always yield the same tensors
321         tokens = list(tokens)
322         tokens.sort()
323         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
324         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
325         self.t_img, self.t_nul = self.token2id["<img>"], self.token2id["<nul>"]
326
327         # Tokenize the train and test sets
328         self.train_input = self.tensorize(self.train_descr)
329         self.test_input = self.tensorize(self.test_descr)
330
331     def batches(self, split="train"):
332         assert split in {"train", "test"}
333         input = self.train_input if split == "train" else self.test_input
334         for batch in tqdm.tqdm(
335             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
336         ):
337             yield self.trim(batch)
338
339     def vocabulary_size(self):
340         return len(self.token2id)
341
342     def compute_missing_properties(
343         self, n_epoch, model, logger, deterministic_synthesis, pruner=None
344     ):
345         acc_nb_requested_properties = []
346         acc_nb_missing_properties = []
347         acc_nb_results = 0
348
349         for input in tqdm.tqdm(
350             self.test_input.split(self.batch_size),
351             dynamic_ncols=True,
352             desc=f"test-properties",
353         ):
354             result = input.clone()
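            # Mask everything from the first <img> token onward: the image
            # part has to be regenerated from the description alone.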
355             ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1)
356             result = (1 - ar_mask) * result + ar_mask * self.t_nul
357             masked_inplace_autoregression(
358                 model,
359                 self.batch_size,
360                 result,
361                 ar_mask,
362                 deterministic_synthesis,
363                 progress_bar_desc=None,
364                 device=self.device,
365             )
366
367             result_descr = self.detensorize(result)
368             np = picoclvr.nb_properties(
369                 result_descr,
370                 height=self.height,
371                 width=self.width,
372                 pruner=pruner,
373             )
374             nb_requested_properties, _, nb_missing_properties = zip(*np)
375             acc_nb_requested_properties += nb_requested_properties
376             acc_nb_missing_properties += nb_missing_properties
377             acc_nb_results += len(result_descr)
378
379         nb_requested_properties = sum(acc_nb_requested_properties)
380         nb_missing_properties = sum(acc_nb_missing_properties)
381
382         prefix = "" if pruner is None else "pruned_"
383         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
384         logger(
385             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
386         )
387         logger(
388             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
389         )
390
391         logger(
392             f"main_test_accuracy {n_epoch} {1-nb_missing_properties/nb_requested_properties}"
393         )
394
395     ######################################################################
396
397     def produce_results(
398         self, n_epoch, model, result_dir, logger, deterministic_synthesis
399     ):
400         self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis)
401
402         if self.pruner_eval is not None:
403             self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis, pruner=self.pruner_eval)
404
405         nb_tokens_to_generate = self.height * self.width + 3
406         result_descr = []
407         nb_per_primer = 8
408         primer = []
409
410         for primer_descr in [
411             "red above green <sep> green top <sep> blue right of red",
412             "there is red <sep> there is yellow <sep> there is blue",
413             "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
414             "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
415         ]:
416             primer += [primer_descr + " <img>"] * nb_per_primer
417
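        # Append height * width + 1 <nul> tokens after each primer and let
        # the model fill them in; the <nul> positions form the generation
        # mask.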
418         result = self.tensorize(primer)
419         fill = result.new_full(
420             result.size()[:-1] + (self.height * self.width + 1,), self.t_nul
421         )
422         result = torch.cat((result, fill), 1)
423         ar_mask = (result == self.t_nul).long()
424         masked_inplace_autoregression(
425             model,
426             self.batch_size,
427             result,
428             ar_mask,
429             deterministic_synthesis,
430             device=self.device,
431         )
432         result_descr = self.detensorize(result)
433
434         np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
435
436         acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
437         acc_nb_results = len(result_descr)
438
439         nb_requested_properties = sum(acc_nb_requested_properties)
440         nb_missing_properties = sum(acc_nb_missing_properties)
441
442         prefix = "demo_"
443         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
444         logger(
445             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
446         )
447         logger(
448             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
449         )
450
451         img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)
452
453         if img.dim() == 5:
454             if img.size(1) == 1:
455                 img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
456             else:
457                 img = torch.cat(
458                     [
459                         torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
460                         for x in img
461                     ],
462                     0,
463                 )
464
465         image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png")
466         torchvision.utils.save_image(
467             img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
468         )
469         logger(f"wrote {image_name}")
470
471
472 ######################################################################
473
474
475 class MNIST(Task):
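    # Autoregressive modeling of raw MNIST digits: each sample is a 28x28
    # image flattened into a sequence of 784 pixel intensities in [0, 255].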
476     def __init__(
477         self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
478     ):
479         super().__init__()
480
481         self.nb_train_samples = nb_train_samples
482         self.nb_test_samples = nb_test_samples
483         self.batch_size = batch_size
484         self.device = device
485         data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True)
486         self.train_input = data_set.data[:nb_train_samples].view(-1, 28 * 28).long()
487         data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
488         self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
489
490     def batches(self, split="train", nb_to_use=-1, desc=None):
491         assert split in {"train", "test"}
492         input = self.train_input if split == "train" else self.test_input
493         if nb_to_use > 0:
494             input = input[:nb_to_use]
495         if desc is None:
496             desc = f"epoch-{split}"
497         for batch in tqdm.tqdm(
498             input.split(self.batch_size), dynamic_ncols=True, desc=desc
499         ):
500             yield batch
501
502     def vocabulary_size(self):
503         return 256
504
505     def produce_results(
506         self, n_epoch, model, result_dir, logger, deterministic_synthesis
507     ):
508         results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
509         ar_mask = torch.full_like(results, 1)
510         masked_inplace_autoregression(
511             model,
512             self.batch_size,
513             results,
514             ar_mask,
515             deterministic_synthesis,
516             device=self.device,
517         )
518         image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
519         torchvision.utils.save_image(
520             1 - results.reshape(-1, 1, 28, 28) / 255.0,
521             image_name,
522             nrow=16,
523             pad_value=0.8,
524         )
525         logger(f"wrote {image_name}")
526
527
528 ######################################################################
529
530 import maze
531
532
533 class Maze(Task):
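    # A sequence is the concatenation of a flattened maze map and the
    # flattened path map; map2seq and seq2map convert between the two views.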
534     def map2seq(self, *m):
535         return torch.cat([x.flatten(1) for x in m], 1)
536
537     def seq2map(self, s):
538         s = s.reshape(s.size(0), -1, self.height, self.width)
539         return (s[:, k] for k in range(s.size(1)))
540
541     def __init__(
542         self,
543         nb_train_samples,
544         nb_test_samples,
545         batch_size,
546         height,
547         width,
548         nb_walls,
549         device=torch.device("cpu"),
550     ):
551         super().__init__()
552
553         self.batch_size = batch_size
554         self.height = height
555         self.width = width
556         self.device = device
557
558         train_mazes, train_paths, _ = maze.create_maze_data(
559             nb_train_samples,
560             height=height,
561             width=width,
562             nb_walls=nb_walls,
563             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
564         )
565         self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
566
567         test_mazes, test_paths, _ = maze.create_maze_data(
568             nb_test_samples,
569             height=height,
570             width=width,
571             nb_walls=nb_walls,
572             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
573         )
574         self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
575
576         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
577
578     def batches(self, split="train", nb_to_use=-1, desc=None):
579         assert split in {"train", "test"}
580         input = self.train_input if split == "train" else self.test_input
581         if nb_to_use > 0:
582             input = input[:nb_to_use]
583         if desc is None:
584             desc = f"epoch-{split}"
585         for batch in tqdm.tqdm(
586             input.split(self.batch_size), dynamic_ncols=True, desc=desc
587         ):
588             yield batch
589
590     def vocabulary_size(self):
591         return self.nb_codes
592
593     def compute_error(
594         self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
595     ):
596         nb_total, nb_correct = 0, 0
597         count = torch.zeros(
598             self.width * self.height,
599             self.width * self.height,
600             device=self.device,
601             dtype=torch.int64,
602         )
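        # count[i, j] will be the number of correctly solved mazes whose
        # optimal path has length i and whose predicted path has length j.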
603
604         for input in self.batches(split, nb_to_use):
605             result = input.clone()
606             ar_mask = result.new_zeros(result.size())
607             ar_mask[:, self.height * self.width :] = 1
608             result *= 1 - ar_mask
609             masked_inplace_autoregression(
610                 model,
611                 self.batch_size,
612                 result,
613                 ar_mask,
614                 deterministic_synthesis,
615                 progress_bar_desc=None,
616                 device=self.device,
617             )
618             mazes, paths = self.seq2map(result)
619             path_correctness = maze.path_correctness(mazes, paths)
620             nb_correct += path_correctness.long().sum()
621             nb_total += mazes.size(0)
622
623             optimal_path_lengths = (
624                 (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
625             )
626             predicted_path_lengths = (
627                 (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
628             )
629             optimal_path_lengths = optimal_path_lengths[path_correctness]
630             predicted_path_lengths = predicted_path_lengths[path_correctness]
631             count[optimal_path_lengths, predicted_path_lengths] += 1
632
633         if count.max() == 0:
634             count = None
635         else:
636             count = count[
637                 : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
638             ]
639
640         return nb_total, nb_correct, count
641
642     def produce_results(
643         self, n_epoch, model, result_dir, logger, deterministic_synthesis
644     ):
645         train_nb_total, train_nb_correct, count = self.compute_error(
646             model,
647             "train",
648             nb_to_use=1000,
649             deterministic_synthesis=deterministic_synthesis,
650         )
651         logger(
652             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
653         )
654
655         test_nb_total, test_nb_correct, count = self.compute_error(
656             model,
657             "test",
658             nb_to_use=1000,
659             deterministic_synthesis=deterministic_synthesis,
660         )
661         logger(
662             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
663         )
664
665         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
666
667         if count is not None:
668             proportion_optimal = count.diagonal().sum().float() / count.sum()
669             logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
670             with open(
671                 os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
672             ) as f:
673                 for i in range(count.size(0)):
674                     for j in range(count.size(1)):
675                         eol = " " if j < count.size(1) - 1 else "\n"
676                         f.write(f"{count[i,j]}{eol}")
677
678         input = self.test_input[:48]
679         result = input.clone()
680         ar_mask = result.new_zeros(result.size())
681         ar_mask[:, self.height * self.width :] = 1
682         result *= 1 - ar_mask
683         masked_inplace_autoregression(
684             model,
685             self.batch_size,
686             result,
687             ar_mask,
688             deterministic_synthesis,
689             device=self.device,
690         )
691
692         mazes, paths = self.seq2map(input)
693         _, predicted_paths = self.seq2map(result)
694
695         filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png")
696         maze.save_image(
697             filename,
698             mazes=mazes,
699             target_paths=paths,
700             predicted_paths=predicted_paths,
701             path_correct=maze.path_correctness(mazes, predicted_paths),
702             path_optimal=maze.path_optimality(paths, predicted_paths),
703         )
704         logger(f"wrote {filename}")
705
706
707 ######################################################################
708
709
710 import snake
711
712
713 class Snake(Task):
714     def __init__(
715         self,
716         nb_train_samples,
717         nb_test_samples,
718         batch_size,
719         height,
720         width,
721         nb_colors,
722         length,
723         prompt_length,
724         device=torch.device("cpu"),
725     ):
726         super().__init__()
727
728         self.batch_size = batch_size
729         self.height = height
730         self.width = width
731         self.device = device
732         self.prompt_length = prompt_length
733
734         self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
735             nb_train_samples,
736             height,
737             width,
738             nb_colors,
739             length,
740             prompt_length,
741             self.device,
742         )
743         self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
744             nb_test_samples,
745             height,
746             width,
747             nb_colors,
748             length,
749             prompt_length,
750             self.device,
751         )
752
753         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
754
755     def batches(self, split="train", nb_to_use=-1, desc=None):
756         assert split in {"train", "test"}
757         input = self.train_input if split == "train" else self.test_input
758         if nb_to_use > 0:
759             input = input[:nb_to_use]
760         if desc is None:
761             desc = f"epoch-{split}"
762         for batch in tqdm.tqdm(
763             input.split(self.batch_size), dynamic_ncols=True, desc=desc
764         ):
765             yield batch
766
767     def vocabulary_size(self):
768         return self.nb_codes
769
770     def produce_results(
771         self, n_epoch, model, result_dir, logger, deterministic_synthesis
772     ):
773         def compute_nb_correct(input, prior_visits):
774             result = input.clone()
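            # Regenerate only the even positions located after the prompt;
            # accuracy is then measured on the cells that were already
            # visited (prior_visits > 0).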
775             i = torch.arange(result.size(1), device=result.device)[None, :]
776             ar_mask = (
777                 torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
778                 .long()
779                 .expand_as(result)
780             )
781             result *= 1 - ar_mask
782
783             masked_inplace_autoregression(
784                 model,
785                 self.batch_size,
786                 result,
787                 ar_mask,
788                 deterministic_synthesis,
789                 device=self.device,
790             )
791
792             nb_total = ((prior_visits > 0) * ar_mask).sum()
793
794             nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum()
795
796             return nb_total, nb_correct
797
798         test_nb_total, test_nb_correct = compute_nb_correct(
799             self.test_input[:1000], self.test_prior_visits[:1000]
800         )
801
802         logger(
803             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
804         )
805
806         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
807
808
809 ######################################################################
810
811
812 import stack
813
814
815 class Stack(Task):
816     def __init__(
817         self,
818         nb_train_samples,
819         nb_test_samples,
820         batch_size,
821         logger,
822         nb_steps,
823         nb_stacks,
824         nb_digits,
825         fraction_values_for_train=None,
826         device=torch.device("cpu"),
827     ):
828         super().__init__()
829
830         self.batch_size = batch_size
831         self.nb_steps = nb_steps
832         self.nb_stacks = nb_stacks
833         self.nb_digits = nb_digits
834         self.device = device
835
836         if fraction_values_for_train is None:
837             values_for_train = None
838             values_for_test = None
839         else:
840             all_values = torch.randperm(10**nb_digits)
841             nb_for_train = int(all_values.size(0) * fraction_values_for_train)
842             values_for_train = all_values[:nb_for_train]
843             values_for_test = all_values[nb_for_train:]
844
845         self.train_input, self.train_stack_counts = stack.generate_sequences(
846             nb_train_samples,
847             nb_steps,
848             nb_stacks,
849             nb_digits,
850             values_for_train,
851             self.device,
852         )
853
854         self.test_input, self.test_stack_counts = stack.generate_sequences(
855             nb_test_samples,
856             nb_steps,
857             nb_stacks,
858             nb_digits,
859             values_for_test,
860             self.device,
861         )
862
863         i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
864         counts = self.test_stack_counts.flatten()[i.flatten()]
865         counts = F.one_hot(counts).sum(0)
866         logger(f"test_pop_stack_counts {counts}")
867
868         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
869
870     def batches(self, split="train", nb_to_use=-1, desc=None):
871         assert split in {"train", "test"}
872         input = self.train_input if split == "train" else self.test_input
873         if nb_to_use > 0:
874             input = input[:nb_to_use]
875         if desc is None:
876             desc = f"epoch-{split}"
877         for batch in tqdm.tqdm(
878             input.split(self.batch_size), dynamic_ncols=True, desc=desc
879         ):
880             yield batch
881
882     def vocabulary_size(self):
883         return self.nb_codes
884
885     def produce_results(
886         self, n_epoch, model, result_dir, logger, deterministic_synthesis
887     ):
888         def compute_nb_correct(input):
889             result = input.clone()
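            # stack.remove_popped_values() blanks the values produced by the
            # pop operations; the positions that differ from the original
            # input are the ones to predict.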
890             stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
891             ar_mask = (result != input).long()
892             masked_inplace_autoregression(
893                 model,
894                 self.batch_size,
895                 result,
896                 ar_mask,
897                 deterministic_synthesis,
898                 device=self.device,
899             )
900
901             errors = ((result != input).long() * ar_mask).reshape(
902                 -1, 1 + self.nb_digits
903             )
904             ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)
905
906             nb_total = ar_mask.max(1).values.sum()
907             nb_correct = nb_total - errors.max(1).values.sum()
908
909             return nb_total, nb_correct
910
911         test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
912
913         logger(
914             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
915         )
916
917         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
918
919         ##############################################################
920         # Log a few generated sequences
921         input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
922         result = input.clone()
923         stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
924         ar_mask = (result != input).long()
925
926         # for n in range(result.size(0)):
927         # logger(
928         # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
929         # )
930
931         masked_inplace_autoregression(
932             model,
933             self.batch_size,
934             result,
935             ar_mask,
936             deterministic_synthesis,
937             device=self.device,
938         )
939
940         for n in range(result.size(0)):
941             logger(
942                 f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
943             )
944         ##############################################################
945
946
947 ######################################################################
948
949 import rpl
950
951
952 class RPL(Task):
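    # Program synthesis task from rpl.py: each sequence contains input/output
    # runs delimited by the <in> and <out> tokens, followed by the program
    # between <prg> and <end>, with <nul> used for padding.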
953     def tensorize(self, sequences):
954         len_max = max([len(x) for x in sequences])
955         return torch.cat(
956             [
957                 torch.tensor(
958                     [
959                         [
960                             self.token2id[str(c)]
961                             for c in s + ["<nul>"] * (len_max - len(s))
962                         ]
963                         for s in sequences
964                     ]
965                 )
966             ],
967             0,
968         )
969
970     def seq2str(self, seq):
971         return " ".join([self.id2token[i] for i in seq])
972
973     def __init__(
974         self,
975         nb_train_samples,
976         nb_test_samples,
977         batch_size,
978         nb_starting_values=3,
979         max_input=9,
980         prog_len=6,
981         nb_runs=5,
982         no_prog=False,
983         logger=None,
984         device=torch.device("cpu"),
985     ):
986         super().__init__()
987
988         self.batch_size = batch_size
989         self.device = device
990         self.no_prog = no_prog
991
992         train_sequences = [
993             rpl.generate(
994                 nb_starting_values=nb_starting_values,
995                 nb_result_values_max=4 * nb_starting_values,
996                 max_input=max_input,
997                 prog_len=prog_len,
998                 nb_runs=nb_runs,
999             )
1000             for _ in tqdm.tqdm(range(nb_train_samples), desc="train-data")
1001         ]
1002
1003         test_sequences = [
1004             rpl.generate(
1005                 nb_starting_values=nb_starting_values,
1006                 nb_result_values_max=4 * nb_starting_values,
1007                 max_input=max_input,
1008                 prog_len=prog_len,
1009                 nb_runs=nb_runs,
1010             )
1011             for _ in tqdm.tqdm(range(nb_test_samples), desc="test-data")
1012         ]
1013
1014         symbols = list(
1015             set(["<nul>"] + [x for l in train_sequences + test_sequences for x in l])
1016         )
1017         val_max = max([x if type(x) is int else 0 for x in symbols])
1018         symbols = list(filter(lambda x: type(x) is str, symbols))
1019         symbols.sort()
1020         symbols += [str(n) for n in range(val_max + 1)]
1021         self.token2id = dict([(c, n) for n, c in enumerate(symbols)])
1022         self.id2token = dict([(n, c) for c, n in self.token2id.items()])
1023
1024         self.t_nul = self.token2id["<nul>"]
1025         self.t_input = self.token2id["<in>"]
1026         self.t_output = self.token2id["<out>"]
1027         self.t_prog = self.token2id["<prg>"]
1028         self.t_end = self.token2id["<end>"]
1029
1030         self.train_input = self.tensorize(train_sequences)
1031         self.test_input = self.tensorize(test_sequences)
1032
1033         if no_prog:
1034             # Excise the program from every train and test example
1035             k = torch.arange(self.train_input.size(1), device=self.train_input.device)[
1036                 None, :
1037             ]
1038             p = (
1039                 ((self.train_input == self.t_prog).long() * k)
1040                 .max(1, keepdim=True)
1041                 .values
1042             )
1043             self.train_input = (
1044                 self.train_input * (k <= p).long()
1045                 + self.t_end * (k == p + 1).long()
1046                 + self.t_nul * (k > p + 1).long()
1047             )
1048             k = torch.arange(self.test_input.size(1), device=self.test_input.device)[
1049                 None, :
1050             ]
1051             p = (
1052                 ((self.test_input == self.t_prog).long() * k)
1053                 .max(1, keepdim=True)
1054                 .values
1055             )
1056             self.test_input = (
1057                 self.test_input * (k <= p).long()
1058                 + self.t_end * (k == p + 1).long()
1059                 + self.t_nul * (k > p + 1).long()
1060             )
1061
1062         if logger is not None:
1063             logger(f"value_max {val_max}")
1064             for x in self.train_input[:25]:
1065                 end = (x != self.t_nul).nonzero().max().item() + 1
1066                 seq = [self.id2token[i.item()] for i in x[:end]]
1067                 s = " ".join(seq)
1068                 logger(f"example_seq {s}")
1069
1070         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1071
1072     def batches(self, split="train", nb_to_use=-1, desc=None):
1073         assert split in {"train", "test"}
1074         input = self.train_input if split == "train" else self.test_input
1075         if nb_to_use > 0:
1076             input = input[:nb_to_use]
1077         if desc is None:
1078             desc = f"epoch-{split}"
1079         for batch in tqdm.tqdm(
1080             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1081         ):
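            # Drop the trailing columns that contain only <nul> in this
            # minibatch (keeping a small margin) before moving it to the
            # device.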
1082             last = (batch != self.t_nul).max(0).values.nonzero().max() + 3
1083             batch = batch[:, :last].to(self.device)
1084             yield batch
1085
1086     def vocabulary_size(self):
1087         return self.nb_codes
1088
1089     def produce_results(
1090         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1091     ):
1092         # --------------------------------------------------------------------
1093         def compute_nb_errors_prog(input, nb_to_log=0):
1094             result = input.clone()
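            # Mask everything strictly after the <prg> token: the model has
            # to regenerate the program from the input/output runs.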
1095             s = (result == self.t_prog).long()
1096             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1097             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1098
1099             masked_inplace_autoregression(
1100                 model,
1101                 self.batch_size,
1102                 result,
1103                 ar_mask,
1104                 deterministic_synthesis,
1105                 device=self.device,
1106             )
1107
1108             sum_nb_total, sum_nb_errors = 0, 0
1109             for one_input, one_result in zip(input, result):
1110                 seq = [self.id2token[i.item()] for i in one_result]
1111                 nb_total, nb_errors, prog, stacks = rpl.compute_nb_errors(seq)
1112                 sum_nb_total += 1
1113                 sum_nb_errors += 0 if nb_errors == 0 else 1
1114                 if nb_to_log > 0:
1115                     gt_seq = [self.id2token[i.item()] for i in one_input]
1116                     _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq)
1117                     gt_prog = " ".join([str(x) for x in gt_prog])
1118                     prog = " ".join([str(x) for x in prog])
1119                     comment = "*" if nb_errors == 0 else "-"
1120                     logger(f"{comment} PROG [{gt_prog}] PREDICTED [{prog}]")
1121                     for start_stack, target_stack, result_stack, correct in stacks:
1122                         comment = "*" if correct else "-"
1123                         start_stack = " ".join([str(x) for x in start_stack])
1124                         target_stack = " ".join([str(x) for x in target_stack])
1125                         result_stack = " ".join([str(x) for x in result_stack])
1126                         logger(
1127                             f"  {comment} [{start_stack}] -> [{target_stack}] PREDICTED [{result_stack}]"
1128                         )
1129                     nb_to_log -= 1
1130
1131             return sum_nb_total, sum_nb_errors
1132
1133         # --------------------------------------------------------------------
1134         def compute_nb_errors_output(input, nb_to_log=0):
1135             result = input.clone()
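            # Mask the positions strictly between the last <out> token and
            # the <prg> token: the model has to predict the values that
            # follow the last <out>.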
1136             k = torch.arange(result.size(1), device=result.device)[None, :]
1137             last_output_idx = (
1138                 ((result == self.t_output) * k).max(dim=1, keepdim=True).values
1139             )
1140             first_prog_idx = (
1141                 ((result == self.t_prog) * k).max(dim=1, keepdim=True).values
1142             )
1143             ar_mask = (k > last_output_idx).long() * (k < first_prog_idx).long()
1144             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1145
1146             masked_inplace_autoregression(
1147                 model,
1148                 self.batch_size,
1149                 result,
1150                 ar_mask,
1151                 deterministic_synthesis,
1152                 device=self.device,
1153             )
1154
1155             sum_nb_total, sum_nb_errors = 0, 0
1156             for one_input, one_result, i, j in zip(
1157                 input, result, last_output_idx, first_prog_idx
1158             ):
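                # i and j are the positions of the last <out> token and of
                # the <prg> token in this particular sequence.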
1159                 seq = [self.id2token[i.item()] for i in one_result]
1160                 sum_nb_total += 1
1161                 correct = (one_input - one_result).abs().max() == 0
1162                 sum_nb_errors += 0 if correct else 1
1163                 if nb_to_log > 0:
1164                     result_stack = [
1165                         self.id2token[i.item()] for i in one_result[i : j + 1]
1166                     ]
1167                     target_stack = [
1168                         self.id2token[i.item()] for i in one_input[i : j + 1]
1169                     ]
1170                     comment = "*" if correct else "-"
1171                     result_stack = " ".join([str(x) for x in result_stack])
1172                     target_stack = " ".join([str(x) for x in target_stack])
1173                     logger(
1174                         f"output_test {comment} [{target_stack}] PREDICTED [{result_stack}]"
1175                     )
1176                     nb_to_log -= 1
1177
1178             return sum_nb_total, sum_nb_errors
1179
1180         # --------------------------------------------------------------------
1181
1182         if not self.no_prog:
1183             test_nb_total, test_nb_errors = compute_nb_errors_prog(
1184                 self.test_input[:1000].to(self.device), nb_to_log=10
1185             )
1186
1187             logger(
1188                 f"accuracy_prog_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1189             )
1190
1191             logger(f"main_test_accuracy {n_epoch} {1-test_nb_errors/test_nb_total}")
1192
1193         test_nb_total, test_nb_errors = compute_nb_errors_output(
1194             self.test_input[:1000].to(self.device), nb_to_log=10
1195         )
1196
1197         logger(
1198             f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1199         )
1200
1201         if save_attention_image is None:
1202             logger("no save_attention_image (is pycairo installed?)")
1203         else:
1204             ns = torch.randint(self.test_input.size(0), (1,)).item()
1205             input = self.test_input[ns : ns + 1].clone()
1206             last = (input != self.t_nul).max(0).values.nonzero().max() + 3
1207             input = input[:, :last].to(self.device)
1208
1209             with torch.autograd.no_grad():
1210                 t = model.training
1211                 model.eval()
1212                 model.record_attention(True)
1213                 model(BracketedSequence(input))
1214                 model.train(t)
1215                 ram = model.retrieve_attention()
1216                 model.record_attention(False)
1217
1218             tokens_output = [self.id2token[i.item()] for i in input[0]]
1219             tokens_input = ["n/a"] + tokens_output[:-1]
1220             for n_head in range(ram[0].size(1)):
1221                 filename = os.path.join(
1222                     result_dir, f"rpl_attention_{n_epoch}_h{n_head}.pdf"
1223                 )
1224                 attention_matrices = [m[0, n_head] for m in ram]
1225                 save_attention_image(
1226                     filename,
1227                     tokens_input,
1228                     tokens_output,
1229                     attention_matrices,
1230                     k_top=10,
1231                     # min_total_attention=0.9,
1232                     token_gap=12,
1233                     layer_gap=50,
1234                 )
1235                 logger(f"wrote {filename}")
1236
1237
1238 ######################################################################
1239
1240
1241 import expr
1242
1243
1244 class Expr(Task):
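    # Character-level task on the expressions generated by expr.py. "#" is
    # the padding character, and everything strictly after the first space
    # of a sequence is what the model has to predict.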
1245     def tensorize(self, sequences):
1246         len_max = max([len(x) for x in sequences])
1247         return torch.cat(
1248             [
1249                 torch.tensor(
1250                     [
1251                         [self.char2id[c] for c in s + "#" * (len_max - len(s))]
1252                         for s in sequences
1253                     ]
1254                 )
1255             ],
1256             0,
1257         ).to(self.device)
1258
1259     def __init__(
1260         self,
1261         nb_train_samples,
1262         nb_test_samples,
1263         nb_variables,
1264         sequence_length,
1265         operand_max,
1266         result_max,
1267         batch_size,
1268         device=torch.device("cpu"),
1269     ):
1270         super().__init__()
1271
1272         self.batch_size = batch_size
1273         self.device = device
1274
1275         train_sequences = expr.generate_sequences(
1276             nb_train_samples,
1277             nb_variables=nb_variables,
1278             length=sequence_length,
1279             operand_max=operand_max,
1280             result_max=result_max,
1281         )
1282
1283         test_sequences = expr.generate_sequences(
1284             nb_test_samples,
1285             nb_variables=nb_variables,
1286             length=sequence_length,
1287             operand_max=operand_max,
1288             result_max=result_max,
1289         )
1290
1291         symbols = list(set("#" + "".join(train_sequences + test_sequences)))
1292         symbols.sort()
1293
1294         self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
1295         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
1296
1297         self.filler, self.space = self.char2id["#"], self.char2id[" "]
1298
1299         self.train_input = self.tensorize(train_sequences)
1300         self.test_input = self.tensorize(test_sequences)
1301
1302         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1303
1304     def batches(self, split="train", nb_to_use=-1, desc=None):
1305         assert split in {"train", "test"}
1306         input = self.train_input if split == "train" else self.test_input
1307         if nb_to_use > 0:
1308             input = input[:nb_to_use]
1309         if desc is None:
1310             desc = f"epoch-{split}"
1311         for batch in tqdm.tqdm(
1312             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1313         ):
1314             last = (batch != self.filler).max(0).values.nonzero().max() + 3
1315             batch = batch[:, :last]
1316             yield batch
1317
1318     def vocabulary_size(self):
1319         return self.nb_codes
1320
1321     def seq2str(self, s):
1322         return "".join([self.id2char[k.item()] for k in s])
1323
1324     def produce_results(
1325         self,
1326         n_epoch,
1327         model,
1328         result_dir,
1329         logger,
1330         deterministic_synthesis,
1331         input_file=None,
1332     ):
1333         def compute_nb_correct(input):
1334             result = input.clone()
1335             s = (result == self.space).long()
1336             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1337             result = (1 - ar_mask) * result + ar_mask * self.filler
1338             masked_inplace_autoregression(
1339                 model,
1340                 self.batch_size,
1341                 result,
1342                 ar_mask,
1343                 deterministic_synthesis,
1344                 device=self.device,
1345             )
1346
1347             nb_total = input.size(0)
1348             nb_correct = (input == result).long().min(1).values.sum()
1349
1350             #######################################################################
1351             # Compute predicted vs. true variable values
1352
1353             nb_delta = torch.zeros(5, dtype=torch.int64)
1354             nb_missed = 0
1355
1356             values_input = expr.extract_results([self.seq2str(s) for s in input])
1357             values_result = expr.extract_results([self.seq2str(s) for s in result])
1358
1359             filename = os.path.join(result_dir, f"expr_result_{n_epoch:04d}.txt")
1360
1361             with open(filename, "w") as f:
1362                 for i, r in zip(values_input, values_result):
1363                     for n, vi in i.items():
1364                         vr = r.get(n)
1365                         f.write(f"{vi} {-1 if vr is None else vr}\n")
1366
1367                         if vr is None or vr < 0:
1368                             nb_missed += 1
1369                         else:
1370                             d = abs(vr - vi)
1371                             if d >= nb_delta.size(0):
1372                                 nb_missed += 1
1373                             else:
1374                                 nb_delta[d] += 1
1375
1376             ######################################################################
1377
1378             return nb_total, nb_correct, nb_delta, nb_missed
1379
1380         (
1381             test_nb_total,
1382             test_nb_correct,
1383             test_nb_delta,
1384             test_nb_missed,
1385         ) = compute_nb_correct(self.test_input[:10000])
1386
1387         logger(
1388             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1389         )
1390
1391         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
1392
1393         nb_total = test_nb_delta.sum() + test_nb_missed
1394         for d in range(test_nb_delta.size(0)):
1395             logger(
1396                 f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%"
1397             )
1398         logger(
1399             f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%"
1400         )
1401
1402         ##############################################################
1403         # Log a few generated sequences
1404         if input_file is None:
1405             input = self.test_input[:10]
1406         else:
1407             with open(input_file, "r") as f:
1408                 sequences = [e.strip() for e in f.readlines()]
1409                 sequences = [s + " " + "#" * 50 for s in sequences]
1410                 input = self.tensorize(sequences)
1411
1412         result = input.clone()
1413         s = (result == self.space).long()
1414         ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1415         result = (1 - ar_mask) * result + ar_mask * self.filler
1416
1417         for n in range(result.size(0)):
1418             logger(f"test_before {self.seq2str(result[n])}")
1419
1420         masked_inplace_autoregression(
1421             model,
1422             self.batch_size,
1423             result,
1424             ar_mask,
1425             deterministic_synthesis,
1426             device=self.device,
1427         )
1428
1429         correct = (1 - ar_mask) * self.space + ar_mask * input
1430         for n in range(result.size(0)):
1431             comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
1432             logger(f"test_after  {self.seq2str(result[n])} {comment}")
1433             logger(f"truth       {self.seq2str(correct[n])}")
1434         ##############################################################
1435
1436
1437 ######################################################################
1438
1439 import grid
1440
1441
1442 class Grid(Task):
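    # Task on the grid descriptions produced by grid.GridFactory: the model
    # has to predict the "true" / "false" tokens of each description, with
    # "#" used for padding.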
1443     # Make a tensor from a list of strings
1444     def str2tensor(self, descr):
1445         token_descr = [s.strip().split(" ") for s in descr]
1446         l = max([len(s) for s in token_descr])
1447         token_descr = [s + ["#"] * (l - len(s)) for s in token_descr]
1448         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
1449         return torch.tensor(id_descr, device=self.device)
1450
1451     # Make a list of strings from a tensor
1452     def tensor2str(self, x):
1453         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
1454
1455     # Trim z to remove as many padding tokens as possible from the left
1456     # and right of the first tensor. If z is a tuple, all its elements
1457     # are trimmed according to the trimming of the first.
1458     def trim(self, z, token="#"):
1459         n = self.token2id[token]
1460         if type(z) == tuple:
1461             x = z[0]
1462             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
1463             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
1464             return tuple([t[:, a:b] for t in z])
1465         else:
1466             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
1467             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
1468             return z[:, a:b]
1469
1470     ######################
1471
1472     def __init__(
1473         self,
1474         nb_train_samples,
1475         nb_test_samples,
1476         batch_size,
1477         size,
1478         logger=None,
1479         device=torch.device("cpu"),
1480     ):
1481         super().__init__()
1482
1483         self.device = device
1484         self.batch_size = batch_size
1485         self.grid_factory = grid.GridFactory(size=size)
1486
1487         if logger is not None:
1488             logger(
1489                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
1490             )
1491
1492         self.train_descr = self.grid_factory.generate_samples(
1493             nb_train_samples, lambda r: tqdm.tqdm(r)
1494         )
1495         self.test_descr = self.grid_factory.generate_samples(
1496             nb_test_samples, lambda r: tqdm.tqdm(r)
1497         )
1498
1499         # Build the tokenizer
1500         tokens = set()
1501         for d in [self.train_descr, self.test_descr]:
1502             for s in d:
1503                 for t in s.strip().split(" "):
1504                     tokens.add(t)
1505         # sort the tokens so that the same descriptions always yield
1506         # the same token ids, and hence the same tensors
1507         tokens = list(tokens)
1508         tokens.sort()
1509         tokens = ["#"] + tokens
1510         self.token2id = {t: n for n, t in enumerate(tokens)}
1511         self.id2token = {n: t for n, t in enumerate(tokens)}
1512         self.t_nul = self.token2id["#"]
1513         self.t_true = self.token2id["true"]
1514         self.t_false = self.token2id["false"]
1515
1516         # Tokenize the train and test sets
1517         self.train_input = self.str2tensor(self.train_descr)
1518         self.test_input = self.str2tensor(self.test_descr)
1519
1520     def batches(self, split="train"):
1521         assert split in {"train", "test"}
1522         input = self.train_input if split == "train" else self.test_input
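        # each batch is trimmed individually: the leading and trailing
        # '#' columns shared by all its sequences are removed, which
        # shortens the sequences actually fed to the model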
1523         for batch in tqdm.tqdm(
1524             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
1525         ):
1526             yield self.trim(batch)
1527
1528     def vocabulary_size(self):
1529         return len(self.token2id)
1530
1531     def produce_results(
1532         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1533     ):
1534         correct = self.test_input[:1000]
1535         result = correct.clone()
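        # only the "true"/"false" answer tokens are masked; the model has
        # to regenerate them, and the accuracy below is measured on those
        # positions only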
1536         ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long()
1537         result *= 1 - ar_mask  # erase the tokens to generate, a bit of paranoia never hurts
1538
1539         logger(f"----------------------------------------------------------")
1540
1541         for e in self.tensor2str(result[:10]):
1542             logger(f"test_before {e}")
1543
1544         masked_inplace_autoregression(
1545             model,
1546             self.batch_size,
1547             result,
1548             ar_mask,
1549             deterministic_synthesis,
1550             device=self.device,
1551         )
1552
1553         logger(f"----------------------------------------------------------")
1554
1555         for e in self.tensor2str(result[:10]):
1556             logger(f"test_after  {e}")
1557
1558         logger(f"----------------------------------------------------------")
1559
1560         nb_total = ar_mask.sum().item()
1561         nb_correct = ((correct == result).long() * ar_mask).sum().item()
1562
1563         logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
1564         logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
1565
1566         if n_epoch in {5, 10, 20}:
1567             if save_attention_image is None:
1568                 logger("no save_attention_image (is pycairo installed?)")
1569             else:
1570                 for k in range(10):
1571                     ns = k  # torch.randint(self.test_input.size(0), (1,)).item()
1572                     input = self.test_input[ns : ns + 1].clone()
1573                     with torch.autograd.no_grad():
1574                         t = model.training
1575                         model.eval()
1576                         model.record_attention(True)
1577                         model(BracketedSequence(input))
1578                         model.train(t)
1579                         ram = model.retrieve_attention()
1580                         model.record_attention(False)
1581
1582                     tokens_output = [self.id2token[t.item()] for t in input[0]]
1583                     tokens_input = ["n/a"] + tokens_output[:-1]
1584                     for n_head in range(ram[0].size(1)):
1585                         filename = os.path.join(
1586                             result_dir,
1587                             f"grid_attention_epoch_{n_epoch}_sample_{k}_head_{n_head}.pdf",
1588                         )
1589                         attention_matrices = [m[0, n_head] for m in ram]
1590                         save_attention_image(
1591                             filename,
1592                             tokens_input,
1593                             tokens_output,
1594                             attention_matrices,
1595                             k_top=10,
1596                             # min_total_attention=0.9,
1597                             token_gap=12,
1598                             layer_gap=50,
1599                         )
1600                         logger(f"wrote {filename}")
1601
1602
1603 ######################################################################
1604
1605 import qmlp
1606
1607
1608 class QMLP(Task):
1609     ######################
1610
1611     def __init__(
1612         self,
1613         nb_train_samples,
1614         nb_test_samples,
1615         batch_size,
1616         result_dir,
1617         logger=None,
1618         device=torch.device("cpu"),
1619     ):
1620         super().__init__()
1621
1622         self.device = device
1623         self.batch_size = batch_size
1624         self.nb_samples_per_mlp = 256
1625
1626         if logger is not None:
1627             logger(
1628                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
1629             )
1630
1631         seq, q_test_set, test_error = qmlp.generate_sequence_and_test_set(
1632             nb_mlps=nb_train_samples + nb_test_samples,
1633             nb_samples=self.nb_samples_per_mlp,
1634             device=self.device,
1635             batch_size=64,
1636             nb_epochs=250,
1637             nb_mlps_per_batch=1024,
1638         )
1639
1640         self.train_input = seq[:nb_train_samples]
1641         self.train_q_test_set = q_test_set[:nb_train_samples]
1642         self.train_ref_test_errors = test_error[:nb_train_samples]
1643         self.test_input = seq[nb_train_samples:]
1644         self.test_q_test_set = q_test_set[nb_train_samples:]
1645         self.test_ref_test_errors = test_error[nb_train_samples:]
1646
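        # save the reference test errors of the MLPs trained by qmlp, to
        # be compared with the errors obtained from the parameters that
        # the model generates (written by produce_results as
        # test_errors_{n_epoch:04d}.dat)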
1647         filename = os.path.join(result_dir, f"train_errors_ref.dat")
1648         with open(filename, "w") as f:
1649             for e in self.train_ref_test_errors:
1650                 f.write(f"{e}\n")
1651
1652         filename = os.path.join(result_dir, f"test_errors_ref.dat")
1653         with open(filename, "w") as f:
1654             for e in self.test_ref_test_errors:
1655                 f.write(f"{e}\n")
1656
1657         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1658
1659     def batches(self, split="train"):
1660         assert split in {"train", "test"}
1661         input = self.train_input if split == "train" else self.test_input
1662         for batch in tqdm.tqdm(
1663             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
1664         ):
1665             yield batch
1666
1667     def vocabulary_size(self):
1668         return self.nb_codes
1669
1670     def produce_results(
1671         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1672     ):
1673         correct = self.test_input[:1000]
1674         result = correct.clone()
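        # the prompt is the first 3 * nb_samples_per_mlp tokens (the
        # quantized training samples) plus one extra token; everything
        # after that, presumably the quantized MLP parameters evaluated
        # below with qmlp.evaluate_q_params, is regenerated by the model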
1675         ar_mask = (
1676             torch.arange(result.size(1), device=result.device)
1677             > self.nb_samples_per_mlp * 3 + 1
1678         ).long()[None, :]
1679         ar_mask = ar_mask.expand_as(result)
1680         result *= 1 - ar_mask  # erase the tokens to generate, a bit of paranoia never hurts
1681
1682         masked_inplace_autoregression(
1683             model,
1684             self.batch_size,
1685             result,
1686             ar_mask,
1687             deterministic_synthesis,
1688             device=self.device,
1689         )
1690
1691         q_train_set = result[:, : self.nb_samples_per_mlp * 3]
1692         q_params = result[:, self.nb_samples_per_mlp * 3 + 1 :]
1693         error_test = qmlp.evaluate_q_params(q_params, self.test_q_test_set)
1694
1695         filename = os.path.join(result_dir, f"test_errors_{n_epoch:04d}.dat")
1696         with open(filename, "w") as f:
1697             for e in error_test:
1698                 f.write(f"{e}\n")
1699
1700
1701 ######################################################################