[mygptrnn.git] / tasks.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, os, tqdm
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 from mygpt import BracketedSequence
16
17 # from graph import save_attention_image
18 save_attention_image = None
19
20 ######################################################################
21
22
23 def masked_inplace_autoregression(
24     model,
25     batch_size,
26     input,
27     ar_mask,
28     deterministic_synthesis,
29     forbidden_tokens=None,
30     progress_bar_desc="autoregression",
31     device=torch.device("cpu"),
32 ):
33     assert input.size() == ar_mask.size()
34
35     batches = zip(input.split(batch_size), ar_mask.split(batch_size))
36
37     if progress_bar_desc is not None:
38         batches = tqdm.tqdm(
39             batches,
40             dynamic_ncols=True,
41             desc=progress_bar_desc,
42             total=(input.size(0) + batch_size - 1) // batch_size,
43         )
44
45     with torch.autograd.no_grad():
46         t = model.training
47         model.eval()
48
49         for input, ar_mask in batches:
50             model.masked_inplace_autoregression(
51                 input, ar_mask, forbidden_tokens, deterministic_synthesis
52             )
53
54         model.train(t)
55
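# Typical calling pattern (sketch, mirroring the tasks below): positions
# where ar_mask is 1 are blanked by the caller and then re-generated token
# by token, positions where it is 0 are kept as conditioning. `model` is
# assumed to expose masked_inplace_autoregression(), as the mygpt models do.
#
#   result = input.clone()
#   result *= 1 - ar_mask  # erase the part to be generated
#   masked_inplace_autoregression(
#       model, batch_size, result, ar_mask,
#       deterministic_synthesis=False, device=device,
#   )
#   # result now holds the model's completions at the masked positions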
56
57 ######################################################################
58
59
60 class Task:
61     def batches(self, split="train", desc=None):
62         pass
63
64     def vocabulary_size(self):
65         pass
66
67     def produce_results(
68         self, n_epoch, model, result_dir, logger, deterministic_synthesis
69     ):
70         pass
71
72
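# A concrete task follows this contract (sketch; the classes below add
# task-specific keyword arguments on top of it):
#
#   class MyTask(Task):
#       def batches(self, split="train", desc=None):
#           ...  # yield LongTensor batches of shape (batch_size, seq_len)
#
#       def vocabulary_size(self):
#           ...  # number of distinct token ids used in those batches
#
#       def produce_results(
#           self, n_epoch, model, result_dir, logger, deterministic_synthesis
#       ):
#           ...  # evaluate the model, log metrics, optionally write files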
73 ####################
74
75 import problems
76
77
78 class SandBox(Task):
79     def __init__(
80         self,
81         problem,
82         nb_train_samples,
83         nb_test_samples,
84         batch_size,
85         logger=None,
86         device=torch.device("cpu"),
87         max_nb_codes=1024,
88     ):
89         super().__init__()
90
91         self.batch_size = batch_size
92         self.device = device
93         self.problem = problem
94
95         self.train_input, self.train_ar_mask = self.problem.generate_sequences(
96             nb_train_samples
97         )
98         self.test_input, self.test_ar_mask = self.problem.generate_sequences(
99             nb_test_samples
100         )
101
102         self.train_input, self.train_ar_mask = self.train_input.to(
103             device
104         ), self.train_ar_mask.to(device)
105         self.test_input, self.test_ar_mask = self.test_input.to(
106             device
107         ), self.test_ar_mask.to(device)
108
109         self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
110
111         # A bit of paranoia never hurts
112         assert self.nb_codes <= max_nb_codes
113         assert self.train_input.min() >= 0
114         assert self.test_input.min() >= 0
115         assert tuple(x.item() for x in self.train_ar_mask.unique()) in {
116             (0,),
117             (1,),
118             (0, 1),
119         }
120         assert tuple(x.item() for x in self.test_ar_mask.unique()) in {
121             (0,),
122             (1,),
123             (0, 1),
124         }
125
126         if logger is not None:
127             for s, a in zip(self.train_input[:100], self.train_ar_mask[:100]):
128                 logger(f"train_sequences {self.problem.seq2str(s)}")
129                 a = "".join(["01"[x.item()] for x in a])
130                 logger(f"                {a}")
131
132     def batches(self, split="train", nb_to_use=-1, desc=None):
133         assert split in {"train", "test"}
134         input = self.train_input if split == "train" else self.test_input
135         if nb_to_use > 0:
136             input = input[:nb_to_use]
137         if desc is None:
138             desc = f"epoch-{split}"
139         for batch in tqdm.tqdm(
140             input.split(self.batch_size), dynamic_ncols=True, desc=desc
141         ):
142             yield batch
143
144     def vocabulary_size(self):
145         return self.nb_codes
146
147     def produce_results(
148         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
149     ):
150         def compute_accuracy(input, ar_mask, logger=None):
151             input, ar_mask = input[:nmax], ar_mask[:nmax]
152             result = input.clone() * (1 - ar_mask)
153
154             masked_inplace_autoregression(
155                 model,
156                 self.batch_size,
157                 result,
158                 ar_mask,
159                 deterministic_synthesis,
160                 progress_bar_desc=None,
161                 device=self.device,
162             )
163
164             log_ground_truth = ar_mask.min() == 0
165
166             if logger is not None:
167                 for sp, st in zip(result[:10], input[:10]):
168                     logger(
169                         f"test_sequences {n_epoch} prediction   {self.problem.seq2str(sp)}"
170                     )
171                     if log_ground_truth:
172                         logger(
173                             f"               {n_epoch} ground truth {self.problem.seq2str(st)}"
174                         )
175
176             nb_total, nb_correct = self.problem.compute_nb_correct(
177                 input, ar_mask, result
178             )
179
180             # nb_total = ar_mask.sum().item()
181             # nb_correct = ((result == input).long() * ar_mask).sum().item()
182
183             return nb_total, nb_correct
184
185         train_nb_total, train_nb_correct = compute_accuracy(
186             self.train_input, self.train_ar_mask
187         )
188
189         logger(
190             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
191         )
192
193         test_nb_total, test_nb_correct = compute_accuracy(
194             self.test_input, self.test_ar_mask, logger
195         )
196
197         logger(
198             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
199         )
200
201         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
202
203         if save_attention_image is not None:
204             for k in range(10):
205                 ns = torch.randint(self.test_input.size(0), (1,)).item()
206                 input = self.test_input[ns : ns + 1].clone()
207
208                 with torch.autograd.no_grad():
209                     t = model.training
210                     model.eval()
211                     # model.record_attention(True)
212                     model(BracketedSequence(input))
213                     model.train(t)
214                     # ram = model.retrieve_attention()
215                     # model.record_attention(False)
216
217                 # tokens_output = [c for c in self.problem.seq2str(input[0])]
218                 # tokens_input = ["n/a"] + tokens_output[:-1]
219                 # for n_head in range(ram[0].size(1)):
220                 # filename = os.path.join(
221                 # result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
222                 # )
223                 # attention_matrices = [m[0, n_head] for m in ram]
224                 # save_attention_image(
225                 # filename,
226                 # tokens_input,
227                 # tokens_output,
228                 # attention_matrices,
229                 # k_top=10,
230                 ##min_total_attention=0.9,
231                 # token_gap=12,
232                 # layer_gap=50,
233                 # )
234                 # logger(f"wrote {filename}")
235
236
237 ######################################################################
238
239 import picoclvr
240
241
242 class PicoCLVR(Task):
243     # Make a tensor from a list of strings
244     def tensorize(self, descr):
245         token_descr = [s.strip().split(" ") for s in descr]
246         l = max([len(s) for s in token_descr])
247         token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
248         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
249         return torch.tensor(id_descr, device=self.device)
250
251     # Make a list of strings from a tensor
252     def detensorize(self, x):
253         def id2token(t):
254             try:
255                 return self.id2token[t.item()]
256             except KeyError:
257                 return "?"
258
259         return [" ".join([id2token(t) for t in r]) for r in x]
260
261     # Trim z to remove as many padding tokens as possible from the left
262     # and right. If z is a tuple, all its elements are trimmed according
263     # to the trimming computed for the first tensor.
264     def trim(self, z, token="<nul>"):
265         n = self.token2id[token]
266         if type(z) == tuple:
267             x = z[0]
268             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
269             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
270             return tuple([t[:, a:b] for t in z])
271         else:
272             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
273             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
274             return z[:, a:b]
275
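    # Example of the effect of trim (sketch; 0 stands for the padding token
    # "<nul>"): columns that are entirely padding at the left and right ends
    # are removed, interior padding columns are kept.
    #
    #   x = torch.tensor([[0, 0, 3, 0, 5, 0],
    #                     [0, 2, 4, 0, 6, 0]])
    #   i = (1 - (F.pad(x, (1, 1), value=0) == 0).min(0).values.long()).cumsum(0)
    #   a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
    #   x[:, a:b]  # tensor([[0, 3, 0, 5],
    #              #         [2, 4, 0, 6]])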
276     ######################
277
278     def __init__(
279         self,
280         nb_train_samples,
281         nb_test_samples,
282         batch_size,
283         height,
284         width,
285         nb_colors=5,
286         logger=None,
287         device=torch.device("cpu"),
288         pruner_train=None,
289         pruner_eval=None,
290     ):
291         super().__init__()
292
293         def generate_descr(nb, cache_suffix, pruner):
294             return picoclvr.generate(
295                 nb,
296                 height=self.height,
297                 width=self.width,
298                 nb_colors=nb_colors,
299                 pruner=pruner,
300             )
301
302         self.height = height
303         self.width = width
304         self.batch_size = batch_size
305         self.device = device
306         self.pruner_train = pruner_train
307         self.pruner_eval = pruner_eval
308
309         if logger is not None:
310             logger(
311                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
312             )
313
314         self.train_descr = generate_descr(
315             nb_train_samples, "train", pruner=self.pruner_train
316         )
317         self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)
318
319         # Build the tokenizer
320         tokens = {"<nul>", "<img>"}
321         for d in [self.train_descr, self.test_descr]:
322             for s in d:
323                 for t in s.strip().split(" "):
324                     tokens.add(t)
325         # Turn this set into a sorted list so that the same descriptions
326         # always yield the same token ids and tensors
327         tokens = list(tokens)
328         tokens.sort()
329         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
330         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
331         self.t_img, self.t_nul = self.token2id["<img>"], self.token2id["<nul>"]
332
333         # Tokenize the train and test sets
334         self.train_input = self.tensorize(self.train_descr)
335         self.test_input = self.tensorize(self.test_descr)
336
337     def batches(self, split="train", desc=None):
338         assert split in {"train", "test"}
339         input = self.train_input if split == "train" else self.test_input
340         for batch in tqdm.tqdm(
341             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
342         ):
343             yield self.trim(batch)
344
345     def vocabulary_size(self):
346         return len(self.token2id)
347
348     def compute_missing_properties(
349         self, n_epoch, model, logger, deterministic_synthesis, pruner=None
350     ):
351         acc_nb_requested_properties = []
352         acc_nb_missing_properties = []
353         acc_nb_results = 0
354
355         for input in tqdm.tqdm(
356             self.test_input.split(self.batch_size),
357             dynamic_ncols=True,
358             desc=f"test-properties",
359         ):
360             result = input.clone()
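            # The cumsum below turns the first occurrence of <img> into a step
            # function: ar_mask is 1 from the <img> token onward, and that whole
            # suffix is blanked to <nul> and re-generated by the model.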
361             ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1)
362             result = (1 - ar_mask) * result + ar_mask * self.t_nul
363             masked_inplace_autoregression(
364                 model,
365                 self.batch_size,
366                 result,
367                 ar_mask,
368                 deterministic_synthesis,
369                 progress_bar_desc=None,
370                 device=self.device,
371             )
372
373             result_descr = self.detensorize(result)
374             np = picoclvr.nb_properties(
375                 result_descr,
376                 height=self.height,
377                 width=self.width,
378                 pruner=pruner,
379             )
380             nb_requested_properties, _, nb_missing_properties = zip(*np)
381             acc_nb_requested_properties += nb_requested_properties
382             acc_nb_missing_properties += nb_missing_properties
383             acc_nb_results += len(result_descr)
384
385         nb_requested_properties = sum(acc_nb_requested_properties)
386         nb_missing_properties = sum(acc_nb_missing_properties)
387
388         prefix = "" if pruner is None else "pruned_"
389         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
390         logger(
391             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
392         )
393         logger(
394             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
395         )
396
397         logger(
398             f"main_test_accuracy {n_epoch} {1-nb_missing_properties/nb_requested_properties}"
399         )
400
401     ######################################################################
402
403     def produce_results(
404         self, n_epoch, model, result_dir, logger, deterministic_synthesis
405     ):
406         self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis)
407
408         if self.pruner_eval is not None:
409             self.compute_missing_properties(
                    n_epoch, model, logger, deterministic_synthesis, pruner=self.pruner_eval
                )
410
411         nb_tokens_to_generate = self.height * self.width + 3
412         result_descr = []
413         nb_per_primer = 8
414         primer = []
415
416         for primer_descr in [
417             "red above green <sep> green top <sep> blue right of red",
418             "there is red <sep> there is yellow <sep> there is blue",
419             "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
420             "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
421         ]:
422             primer += [primer_descr + " <img>"] * nb_per_primer
423
424         result = self.tensorize(primer)
425         fill = result.new_full(
426             result.size()[:-1] + (self.height * self.width + 1,), self.t_nul
427         )
428         result = torch.cat((result, fill), 1)
429         ar_mask = (result == self.t_nul).long()
430         masked_inplace_autoregression(
431             model,
432             self.batch_size,
433             result,
434             ar_mask,
435             deterministic_synthesis,
436             device=self.device,
437         )
438         result_descr = self.detensorize(result)
439
440         np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
441
442         acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
443         acc_nb_results = len(result_descr)
444
445         nb_requested_properties = sum(acc_nb_requested_properties)
446         nb_missing_properties = sum(acc_nb_missing_properties)
447
448         prefix = "demo_"
449         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
450         logger(
451             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
452         )
453         logger(
454             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
455         )
456
457         img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)
458
459         if img.dim() == 5:
460             if img.size(1) == 1:
461                 img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
462             else:
463                 img = torch.cat(
464                     [
465                         torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
466                         for x in img
467                     ],
468                     0,
469                 )
470
471         image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png")
472         torchvision.utils.save_image(
473             img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
474         )
475         logger(f"wrote {image_name}")
476
477
478 ######################################################################
479
480
481 class MNIST(Task):
482     def __init__(
483         self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
484     ):
485         super().__init__()
486
487         self.nb_train_samples = nb_train_samples
488         self.nb_test_samples = nb_test_samples
489         self.batch_size = batch_size
490         self.device = device
491         data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True)
492         self.train_input = data_set.data[:nb_train_samples].view(-1, 28 * 28).long()
493         data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
494         self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
495
496     def batches(self, split="train", nb_to_use=-1, desc=None):
497         assert split in {"train", "test"}
498         input = self.train_input if split == "train" else self.test_input
499         if nb_to_use > 0:
500             input = input[:nb_to_use]
501         if desc is None:
502             desc = f"epoch-{split}"
503         for batch in tqdm.tqdm(
504             input.split(self.batch_size), dynamic_ncols=True, desc=desc
505         ):
506             yield batch
507
508     def vocabulary_size(self):
509         return 256
510
511     def produce_results(
512         self, n_epoch, model, result_dir, logger, deterministic_synthesis
513     ):
514         results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
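        # ar_mask is all ones, so all 28*28 pixel tokens of the 64 sequences
        # are generated from scratch.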
515         ar_mask = torch.full_like(results, 1)
516         masked_inplace_autoregression(
517             model,
518             self.batch_size,
519             results,
520             ar_mask,
521             deterministic_synthesis,
522             device=self.device,
523         )
524         image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
525         torchvision.utils.save_image(
526             1 - results.reshape(-1, 1, 28, 28) / 255.0,
527             image_name,
528             nrow=16,
529             pad_value=0.8,
530         )
531         logger(f"wrote {image_name}")
532
533
534 ######################################################################
535
536 import maze
537
538
539 class Maze(Task):
540     def map2seq(self, *m):
541         return torch.cat([x.flatten(1) for x in m], 1)
542
543     def seq2map(self, s):
544         s = s.reshape(s.size(0), -1, self.height, self.width)
545         return (s[:, k] for k in range(s.size(1)))
546
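    # A sequence is the flattened maze (height*width tokens) followed by the
    # flattened path (height*width tokens), i.e. 2*height*width tokens in
    # total; seq2map undoes this layout.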
547     def __init__(
548         self,
549         nb_train_samples,
550         nb_test_samples,
551         batch_size,
552         height,
553         width,
554         nb_walls,
555         device=torch.device("cpu"),
556     ):
557         super().__init__()
558
559         self.batch_size = batch_size
560         self.height = height
561         self.width = width
562         self.device = device
563
564         train_mazes, train_paths, _ = maze.create_maze_data(
565             nb_train_samples,
566             height=height,
567             width=width,
568             nb_walls=nb_walls,
569             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
570         )
571         self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
572
573         test_mazes, test_paths, _ = maze.create_maze_data(
574             nb_test_samples,
575             height=height,
576             width=width,
577             nb_walls=nb_walls,
578             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
579         )
580         self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
581
582         self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
583
584     def batches(self, split="train", nb_to_use=-1, desc=None):
585         assert split in {"train", "test"}
586         input = self.train_input if split == "train" else self.test_input
587         if nb_to_use > 0:
588             input = input[:nb_to_use]
589         if desc is None:
590             desc = f"epoch-{split}"
591         for batch in tqdm.tqdm(
592             input.split(self.batch_size), dynamic_ncols=True, desc=desc
593         ):
594             yield batch
595
596     def vocabulary_size(self):
597         return self.nb_codes
598
599     def compute_error(
600         self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
601     ):
602         nb_total, nb_correct = 0, 0
603         count = torch.zeros(
604             self.width * self.height,
605             self.width * self.height,
606             device=self.device,
607             dtype=torch.int64,
608         )
609
610         for input in self.batches(split, nb_to_use):
611             result = input.clone()
612             ar_mask = result.new_zeros(result.size())
613             ar_mask[:, self.height * self.width :] = 1
614             result *= 1 - ar_mask
615             masked_inplace_autoregression(
616                 model,
617                 self.batch_size,
618                 result,
619                 ar_mask,
620                 deterministic_synthesis,
621                 progress_bar_desc=None,
622                 device=self.device,
623             )
624             mazes, paths = self.seq2map(result)
625             path_correctness = maze.path_correctness(mazes, paths)
626             nb_correct += path_correctness.long().sum()
627             nb_total += mazes.size(0)
628
629             optimal_path_lengths = (
630                 (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
631             )
632             predicted_path_lengths = (
633                 (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
634             )
635             optimal_path_lengths = optimal_path_lengths[path_correctness]
636             predicted_path_lengths = predicted_path_lengths[path_correctness]
637             # accumulate duplicates of the same (optimal, predicted) pair correctly
                count.index_put_((optimal_path_lengths, predicted_path_lengths), torch.ones_like(optimal_path_lengths), accumulate=True)
638
639         if count.max() == 0:
640             count = None
641         else:
642             count = count[
643                 : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
644             ]
645
646         return nb_total, nb_correct, count
647
648     def produce_results(
649         self, n_epoch, model, result_dir, logger, deterministic_synthesis
650     ):
651         train_nb_total, train_nb_correct, count = self.compute_error(
652             model,
653             "train",
654             nb_to_use=1000,
655             deterministic_synthesis=deterministic_synthesis,
656         )
657         logger(
658             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
659         )
660
661         test_nb_total, test_nb_correct, count = self.compute_error(
662             model,
663             "test",
664             nb_to_use=1000,
665             deterministic_synthesis=deterministic_synthesis,
666         )
667         logger(
668             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
669         )
670
671         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
672
673         if count is not None:
674             proportion_optimal = count.diagonal().sum().float() / count.sum()
675             logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
676             with open(
677                 os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
678             ) as f:
679                 for i in range(count.size(0)):
680                     for j in range(count.size(1)):
681                         eol = " " if j < count.size(1) - 1 else "\n"
682                         f.write(f"{count[i,j]}{eol}")
683
684         input = self.test_input[:48]
685         result = input.clone()
686         ar_mask = result.new_zeros(result.size())
687         ar_mask[:, self.height * self.width :] = 1
688         result *= 1 - ar_mask
689         masked_inplace_autoregression(
690             model,
691             self.batch_size,
692             result,
693             ar_mask,
694             deterministic_synthesis,
695             device=self.device,
696         )
697
698         mazes, paths = self.seq2map(input)
699         _, predicted_paths = self.seq2map(result)
700
701         filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png")
702         maze.save_image(
703             filename,
704             mazes=mazes,
705             target_paths=paths,
706             predicted_paths=predicted_paths,
707             path_correct=maze.path_correctness(mazes, predicted_paths),
708             path_optimal=maze.path_optimality(paths, predicted_paths),
709         )
710         logger(f"wrote {filename}")
711
712
713 ######################################################################
714
715
716 import snake
717
718
719 class Snake(Task):
720     def __init__(
721         self,
722         nb_train_samples,
723         nb_test_samples,
724         batch_size,
725         height,
726         width,
727         nb_colors,
728         length,
729         prompt_length,
730         device=torch.device("cpu"),
731     ):
732         super().__init__()
733
734         self.batch_size = batch_size
735         self.height = height
736         self.width = width
737         self.device = device
738         self.prompt_length = prompt_length
739
740         self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
741             nb_train_samples,
742             height,
743             width,
744             nb_colors,
745             length,
746             prompt_length,
747             self.device,
748         )
749         self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
750             nb_test_samples,
751             height,
752             width,
753             nb_colors,
754             length,
755             prompt_length,
756             self.device,
757         )
758
759         self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
760
761     def batches(self, split="train", nb_to_use=-1, desc=None):
762         assert split in {"train", "test"}
763         input = self.train_input if split == "train" else self.test_input
764         if nb_to_use > 0:
765             input = input[:nb_to_use]
766         if desc is None:
767             desc = f"epoch-{split}"
768         for batch in tqdm.tqdm(
769             input.split(self.batch_size), dynamic_ncols=True, desc=desc
770         ):
771             yield batch
772
773     def vocabulary_size(self):
774         return self.nb_codes
775
776     def produce_results(
777         self, n_epoch, model, result_dir, logger, deterministic_synthesis
778     ):
779         def compute_nb_correct(input, prior_visits):
780             result = input.clone()
781             i = torch.arange(result.size(1), device=result.device)[None, :]
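            # The mask selects every other position (even indices) once past the
            # prompt: those tokens are blanked below and re-generated, while the
            # interleaved tokens at odd indices are kept as given.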
782             ar_mask = (
783                 torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
784                 .long()
785                 .expand_as(result)
786             )
787             result *= 1 - ar_mask
788
789             masked_inplace_autoregression(
790                 model,
791                 self.batch_size,
792                 result,
793                 ar_mask,
794                 deterministic_synthesis,
795                 device=self.device,
796             )
797
798             nb_total = ((prior_visits > 0) * ar_mask).sum()
799
800             nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum()
801
802             return nb_total, nb_correct
803
804         test_nb_total, test_nb_correct = compute_nb_correct(
805             self.test_input[:1000], self.test_prior_visits[:1000]
806         )
807
808         logger(
809             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
810         )
811
812         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
813
814
815 ######################################################################
816
817
818 import stack
819
820
821 class Stack(Task):
822     def __init__(
823         self,
824         nb_train_samples,
825         nb_test_samples,
826         batch_size,
827         logger,
828         nb_steps,
829         nb_stacks,
830         nb_digits,
831         fraction_values_for_train=None,
832         device=torch.device("cpu"),
833     ):
834         super().__init__()
835
836         self.batch_size = batch_size
837         self.nb_steps = nb_steps
838         self.nb_stacks = nb_stacks
839         self.nb_digits = nb_digits
840         self.device = device
841
842         if fraction_values_for_train is None:
843             values_for_train = None
844             values_for_test = None
845         else:
846             all = torch.randperm(10**nb_digits)
847             nb_for_train = int(all.size(0) * fraction_values_for_train)
848             values_for_train = all[:nb_for_train]
849             values_for_test = all[nb_for_train:]
850
851         self.train_input, self.train_stack_counts = stack.generate_sequences(
852             nb_train_samples,
853             nb_steps,
854             nb_stacks,
855             nb_digits,
856             values_for_train,
857             self.device,
858         )
859
860         self.test_input, self.test_stack_counts = stack.generate_sequences(
861             nb_test_samples,
862             nb_steps,
863             nb_stacks,
864             nb_digits,
865             values_for_test,
866             self.device,
867         )
868
869         i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
870         counts = self.test_stack_counts.flatten()[i.flatten()]
871         counts = F.one_hot(counts).sum(0)
872         logger(f"test_pop_stack_counts {counts}")
873
874         self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
875
876     def batches(self, split="train", nb_to_use=-1, desc=None):
877         assert split in {"train", "test"}
878         input = self.train_input if split == "train" else self.test_input
879         if nb_to_use > 0:
880             input = input[:nb_to_use]
881         if desc is None:
882             desc = f"epoch-{split}"
883         for batch in tqdm.tqdm(
884             input.split(self.batch_size), dynamic_ncols=True, desc=desc
885         ):
886             yield batch
887
888     def vocabulary_size(self):
889         return self.nb_codes
890
891     def produce_results(
892         self, n_epoch, model, result_dir, logger, deterministic_synthesis
893     ):
894         def compute_nb_correct(input):
895             result = input.clone()
896             stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
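            # remove_popped_values blanks the values produced by the pop
            # operations, so the positions where result now differs from input
            # are exactly the ones the model has to predict.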
897
898             ar_mask = (result != input).long()
899             result *= 1 - ar_mask
900
901             masked_inplace_autoregression(
902                 model,
903                 self.batch_size,
904                 result,
905                 ar_mask,
906                 deterministic_synthesis,
907                 device=self.device,
908             )
909
910             errors = ((result != input).long() * ar_mask).reshape(
911                 -1, 1 + self.nb_digits
912             )
913             ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)
914
915             nb_total = ar_mask.max(1).values.sum()
916             nb_correct = nb_total - errors.max(1).values.sum()
917
918             return nb_total, nb_correct
919
920         test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
921
922         logger(
923             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
924         )
925
926         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
927
928         ##############################################################
929         # Log a few generated sequences
930         input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
931         result = input.clone()
932         stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
933         ar_mask = (result != input).long()
934
935         for n in range(result.size(0)):
936             logger(
937                 f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
938             )
939
940         result *= 1 - ar_mask
941
942         masked_inplace_autoregression(
943             model,
944             self.batch_size,
945             result,
946             ar_mask,
947             deterministic_synthesis,
948             device=self.device,
949         )
950
951         for n in range(result.size(0)):
952             logger(
953                 f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
954             )
955         ##############################################################
956
957
958 ######################################################################
959
960 import rpl
961
962
963 class RPL(Task):
964     def tensorize(self, sequences):
965         len_max = max([len(x) for x in sequences])
966         return torch.cat(
967             [
968                 torch.tensor(
969                     [
970                         [
971                             self.token2id[str(c)]
972                             for c in s + ["<nul>"] * (len_max - len(s))
973                         ]
974                         for s in sequences
975                     ]
976                 )
977             ],
978             0,
979         )
980
981     def seq2str(self, seq):
982         return " ".join([self.id2token[i] for i in seq])
983
984     def __init__(
985         self,
986         nb_train_samples,
987         nb_test_samples,
988         batch_size,
989         nb_starting_values=3,
990         max_input=9,
991         prog_len=6,
992         nb_runs=5,
993         no_prog=False,
994         logger=None,
995         device=torch.device("cpu"),
996     ):
997         super().__init__()
998
999         self.batch_size = batch_size
1000         self.device = device
1001         self.no_prog = no_prog
1002
1003         train_sequences = [
1004             rpl.generate(
1005                 nb_starting_values=nb_starting_values,
1006                 nb_result_values_max=4 * nb_starting_values,
1007                 max_input=max_input,
1008                 prog_len=prog_len,
1009                 nb_runs=nb_runs,
1010             )
1011             for _ in tqdm.tqdm(range(nb_train_samples), desc="train-data")
1012         ]
1013
1014         test_sequences = [
1015             rpl.generate(
1016                 nb_starting_values=nb_starting_values,
1017                 nb_result_values_max=4 * nb_starting_values,
1018                 max_input=max_input,
1019                 prog_len=prog_len,
1020                 nb_runs=nb_runs,
1021             )
1022             for _ in tqdm.tqdm(range(nb_test_samples), desc="test-data")
1023         ]
1024
1025         symbols = list(
1026             set(["<nul>"] + [x for l in train_sequences + test_sequences for x in l])
1027         )
1028         val_max = max([x if type(x) is int else 0 for x in symbols])
1029         symbols = list(filter(lambda x: type(x) is str, symbols))
1030         symbols.sort()
1031         symbols += [str(n) for n in range(val_max + 1)]
1032         self.token2id = dict([(c, n) for n, c in enumerate(symbols)])
1033         self.id2token = dict([(n, c) for c, n in self.token2id.items()])
1034
1035         self.t_nul = self.token2id["<nul>"]
1036         self.t_input = self.token2id["<in>"]
1037         self.t_output = self.token2id["<out>"]
1038         self.t_prog = self.token2id["<prg>"]
1039         self.t_end = self.token2id["<end>"]
1040
1041         self.train_input = self.tensorize(train_sequences)
1042         self.test_input = self.tensorize(test_sequences)
1043
1044         if no_prog:
1045             # Excise the program from every train and test example
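            # k is the position index and p the position of the <prg> token in
            # each sequence: everything up to and including <prg> is kept, the
            # next position becomes <end>, and the rest is filled with <nul>.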
1046             k = torch.arange(self.train_input.size(1), device=self.train_input.device)[
1047                 None, :
1048             ]
1049             p = (
1050                 ((self.train_input == self.t_prog).long() * k)
1051                 .max(1, keepdim=True)
1052                 .values
1053             )
1054             self.train_input = (
1055                 self.train_input * (k <= p).long()
1056                 + self.t_end * (k == p + 1).long()
1057                 + self.t_nul * (k > p + 1).long()
1058             )
1059             k = torch.arange(self.test_input.size(1), device=self.test_input.device)[
1060                 None, :
1061             ]
1062             p = (
1063                 ((self.test_input == self.t_prog).long() * k)
1064                 .max(1, keepdim=True)
1065                 .values
1066             )
1067             self.test_input = (
1068                 self.test_input * (k <= p).long()
1069                 + self.t_end * (k == p + 1).long()
1070                 + self.t_nul * (k > p + 1).long()
1071             )
1072
1073         if logger is not None:
1074             logger(f"value_max {val_max}")
1075             for x in self.train_input[:25]:
1076                 end = (x != self.t_nul).nonzero().max().item() + 1
1077                 seq = [self.id2token[i.item()] for i in x[:end]]
1078                 s = " ".join(seq)
1079                 logger(f"example_seq {s}")
1080
1081         self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
1082
1083     def batches(self, split="train", nb_to_use=-1, desc=None):
1084         assert split in {"train", "test"}
1085         input = self.train_input if split == "train" else self.test_input
1086         if nb_to_use > 0:
1087             input = input[:nb_to_use]
1088         if desc is None:
1089             desc = f"epoch-{split}"
1090         for batch in tqdm.tqdm(
1091             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1092         ):
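            # Drop the all-<nul> columns at the end of the batch, keeping a
            # small margin, before moving it to the device.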
1093             last = (batch != self.t_nul).max(0).values.nonzero().max() + 3
1094             batch = batch[:, :last].to(self.device)
1095             yield batch
1096
1097     def vocabulary_size(self):
1098         return self.nb_codes
1099
1100     def produce_results(
1101         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1102     ):
1103         # --------------------------------------------------------------------
1104         def compute_nb_errors_prog(input, nb_to_log=0):
1105             result = input.clone()
1106             s = (result == self.t_prog).long()
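            # s.cumsum(dim=1) - s is 1 strictly after the <prg> token: the
            # program body is blanked and re-generated, while <prg> itself and
            # everything before it are kept.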
1107             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1108             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1109
1110             masked_inplace_autoregression(
1111                 model,
1112                 self.batch_size,
1113                 result,
1114                 ar_mask,
1115                 deterministic_synthesis,
1116                 device=self.device,
1117             )
1118
1119             sum_nb_total, sum_nb_errors = 0, 0
1120             for one_input, one_result in zip(input, result):
1121                 seq = [self.id2token[i.item()] for i in one_result]
1122                 nb_total, nb_errors, prog, stacks = rpl.compute_nb_errors(seq)
1123                 sum_nb_total += 1
1124                 sum_nb_errors += 0 if nb_errors == 0 else 1
1125                 if nb_to_log > 0:
1126                     gt_seq = [self.id2token[i.item()] for i in one_input]
1127                     _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq)
1128                     gt_prog = " ".join([str(x) for x in gt_prog])
1129                     prog = " ".join([str(x) for x in prog])
1130                     comment = "*" if nb_errors == 0 else "-"
1131                     logger(f"{comment} PROG [{gt_prog}] PREDICTED [{prog}]")
1132                     for start_stack, target_stack, result_stack, correct in stacks:
1133                         comment = "*" if correct else "-"
1134                         start_stack = " ".join([str(x) for x in start_stack])
1135                         target_stack = " ".join([str(x) for x in target_stack])
1136                         result_stack = " ".join([str(x) for x in result_stack])
1137                         logger(
1138                             f"  {comment} [{start_stack}] -> [{target_stack}] PREDICTED [{result_stack}]"
1139                         )
1140                     nb_to_log -= 1
1141
1142             return sum_nb_total, sum_nb_errors
1143
1144         # --------------------------------------------------------------------
1145         def compute_nb_errors_output(input, nb_to_log=0):
1146             result = input.clone()
1147             k = torch.arange(result.size(1), device=result.device)[None, :]
1148             last_output_idx = (
1149                 ((result == self.t_output) * k).max(dim=1, keepdim=True).values
1150             )
1151             first_prog_idx = (
1152                 ((result == self.t_prog) * k).max(dim=1, keepdim=True).values
1153             )
1154             ar_mask = (k > last_output_idx).long() * (k < first_prog_idx).long()
1155             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1156
1157             masked_inplace_autoregression(
1158                 model,
1159                 self.batch_size,
1160                 result,
1161                 ar_mask,
1162                 deterministic_synthesis,
1163                 device=self.device,
1164             )
1165
1166             sum_nb_total, sum_nb_errors = 0, 0
1167             for one_input, one_result, i, j in zip(
1168                 input, result, last_output_idx, first_prog_idx
1169             ):
1170                 seq = [self.id2token[i.item()] for i in one_result]
1171                 sum_nb_total += 1
1172                 correct = (one_input - one_result).abs().max() == 0
1173                 sum_nb_errors += 0 if correct else 1
1174                 if nb_to_log > 0:
1175                     result_stack = [
1176                         self.id2token[i.item()] for i in one_result[i : j + 1]
1177                     ]
1178                     target_stack = [
1179                         self.id2token[i.item()] for i in one_input[i : j + 1]
1180                     ]
1181                     comment = "*" if correct else "-"
1182                     result_stack = " ".join([str(x) for x in result_stack])
1183                     target_stack = " ".join([str(x) for x in target_stack])
1184                     logger(
1185                         f"output_test {comment} [{target_stack}] PREDICTED [{result_stack}]"
1186                     )
1187                     nb_to_log -= 1
1188
1189             return sum_nb_total, sum_nb_errors
1190
1191         # --------------------------------------------------------------------
1192
1193         if not self.no_prog:
1194             test_nb_total, test_nb_errors = compute_nb_errors_prog(
1195                 self.test_input[:1000].to(self.device), nb_to_log=10
1196             )
1197
1198             logger(
1199                 f"accuracy_prog_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1200             )
1201
1202             logger(f"main_test_accuracy {n_epoch} {1-test_nb_errors/test_nb_total}")
1203
1204         test_nb_total, test_nb_errors = compute_nb_errors_output(
1205             self.test_input[:1000].to(self.device), nb_to_log=10
1206         )
1207
1208         logger(
1209             f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1210         )
1211
1212         if save_attention_image is not None:
1213             ns = torch.randint(self.test_input.size(0), (1,)).item()
1214             input = self.test_input[ns : ns + 1].clone()
1215             last = (input != self.t_nul).max(0).values.nonzero().max() + 3
1216             input = input[:, :last].to(self.device)
1217
1218             with torch.autograd.no_grad():
1219                 t = model.training
1220                 model.eval()
1221                 model.record_attention(True)
1222                 model(BracketedSequence(input))
1223                 model.train(t)
1224                 ram = model.retrieve_attention()
1225                 model.record_attention(False)
1226
1227             tokens_output = [self.id2token[i.item()] for i in input[0]]
1228             tokens_input = ["n/a"] + tokens_output[:-1]
1229             for n_head in range(ram[0].size(1)):
1230                 filename = os.path.join(
1231                     result_dir, f"rpl_attention_{n_epoch}_h{n_head}.pdf"
1232                 )
1233                 attention_matrices = [m[0, n_head] for m in ram]
1234                 save_attention_image(
1235                     filename,
1236                     tokens_input,
1237                     tokens_output,
1238                     attention_matrices,
1239                     k_top=10,
1240                     # min_total_attention=0.9,
1241                     token_gap=12,
1242                     layer_gap=50,
1243                 )
1244                 logger(f"wrote {filename}")
1245
1246
1247 ######################################################################
1248
1249
1250 import expr
1251
1252
1253 class Expr(Task):
1254     def tensorize(self, sequences):
1255         len_max = max([len(x) for x in sequences])
1256         return torch.cat(
1257             [
1258                 torch.tensor(
1259                     [
1260                         [self.char2id[c] for c in s + "#" * (len_max - len(s))]
1261                         for s in sequences
1262                     ]
1263                 )
1264             ],
1265             0,
1266         ).to(self.device)
1267
1268     def __init__(
1269         self,
1270         nb_train_samples,
1271         nb_test_samples,
1272         nb_variables,
1273         sequence_length,
1274         operand_max,
1275         result_max,
1276         batch_size,
1277         device=torch.device("cpu"),
1278     ):
1279         super().__init__()
1280
1281         self.batch_size = batch_size
1282         self.device = device
1283
1284         train_sequences = expr.generate_sequences(
1285             nb_train_samples,
1286             nb_variables=nb_variables,
1287             length=sequence_length,
1288             operand_max=operand_max,
1289             result_max=result_max,
1290         )
1291
1292         test_sequences = expr.generate_sequences(
1293             nb_test_samples,
1294             nb_variables=nb_variables,
1295             length=sequence_length,
1296             operand_max=operand_max,
1297             result_max=result_max,
1298         )
1299
1300         symbols = list(set("#" + "".join(train_sequences + test_sequences)))
1301         symbols.sort()
1302
1303         self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
1304         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
1305
1306         self.filler, self.space = self.char2id["#"], self.char2id[" "]
1307
1308         self.train_input = self.tensorize(train_sequences)
1309         self.test_input = self.tensorize(test_sequences)
1310
1311         self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
1312
1313     def batches(self, split="train", nb_to_use=-1, desc=None):
1314         assert split in {"train", "test"}
1315         input = self.train_input if split == "train" else self.test_input
1316         if nb_to_use > 0:
1317             input = input[:nb_to_use]
1318         if desc is None:
1319             desc = f"epoch-{split}"
1320         for batch in tqdm.tqdm(
1321             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1322         ):
1323             last = (batch != self.filler).max(0).values.nonzero().max() + 3
1324             batch = batch[:, :last]
1325             yield batch
1326
1327     def vocabulary_size(self):
1328         return self.nb_codes
1329
1330     def seq2str(self, s):
1331         return "".join([self.id2char[k.item()] for k in s])
1332
1333     def produce_results(
1334         self,
1335         n_epoch,
1336         model,
1337         result_dir,
1338         logger,
1339         deterministic_synthesis,
1340         input_file=None,
1341     ):
1342         def compute_nb_correct(input):
1343             result = input.clone()
1344             s = (result == self.space).long()
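            # Same cumsum trick as in RPL: the mask is 1 strictly after the
            # first space token, and that suffix is blanked to the filler and
            # re-generated.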
1345             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1346             result = (1 - ar_mask) * result + ar_mask * self.filler
1347             masked_inplace_autoregression(
1348                 model,
1349                 self.batch_size,
1350                 result,
1351                 ar_mask,
1352                 deterministic_synthesis,
1353                 device=self.device,
1354             )
1355
1356             nb_total = input.size(0)
1357             nb_correct = (input == result).long().min(1).values.sum()
1358
1359             #######################################################################
1360             # Compute predicted vs. true variable values
1361
1362             nb_delta = torch.zeros(5, dtype=torch.int64)
1363             nb_missed = 0
1364
1365             values_input = expr.extract_results([self.seq2str(s) for s in input])
1366             values_result = expr.extract_results([self.seq2str(s) for s in result])
1367
1368             filename = os.path.join(result_dir, f"expr_result_{n_epoch:04d}.txt")
1369
1370             with open(filename, "w") as f:
1371                 for i, r in zip(values_input, values_result):
1372                     for n, vi in i.items():
1373                         vr = r.get(n)
1374                         f.write(f"{vi} {-1 if vr is None else vr}\n")
1375
1376                         if vr is None or vr < 0:
1377                             nb_missed += 1
1378                         else:
1379                             d = abs(vr - vi)
1380                             if d >= nb_delta.size(0):
1381                                 nb_missed += 1
1382                             else:
1383                                 nb_delta[d] += 1
1384
1385             ######################################################################
1386
1387             return nb_total, nb_correct, nb_delta, nb_missed
1388
1389         (
1390             test_nb_total,
1391             test_nb_correct,
1392             test_nb_delta,
1393             test_nb_missed,
1394         ) = compute_nb_correct(self.test_input[:10000])
1395
1396         logger(
1397             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1398         )
1399
1400         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
1401
1402         nb_total = test_nb_delta.sum() + test_nb_missed
1403         for d in range(test_nb_delta.size(0)):
1404             logger(
1405                 f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%"
1406             )
1407         logger(
1408             f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%"
1409         )
1410
1411         ##############################################################
1412         # Log a few generated sequences
1413         if input_file is None:
1414             input = self.test_input[:10]
1415         else:
1416             with open(input_file, "r") as f:
1417                 sequences = [e.strip() for e in f.readlines()]
1418                 sequences = [s + " " + "#" * 50 for s in sequences]
1419                 input = self.tensorize(sequences)
1420
1421         result = input.clone()
1422         s = (result == self.space).long()
1423         ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1424         result = (1 - ar_mask) * result + ar_mask * self.filler
1425
1426         for n in range(result.size(0)):
1427             logger(f"test_before {self.seq2str(result[n])}")
1428
1429         masked_inplace_autoregression(
1430             model,
1431             self.batch_size,
1432             result,
1433             ar_mask,
1434             deterministic_synthesis,
1435             device=self.device,
1436         )
1437
1438         correct = (1 - ar_mask) * self.space + ar_mask * input
1439         for n in range(result.size(0)):
1440             comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
1441             logger(f"test_after  {self.seq2str(result[n])} {comment}")
1442             logger(f"truth       {self.seq2str(correct[n])}")
1443         ##############################################################
1444
1445
1446 ######################################################################
1447
1448 import grid
1449
1450
1451 class Grid(Task):
1452     # Make a tensor from a list of strings
1453     def str2tensor(self, descr):
1454         token_descr = [s.strip().split(" ") for s in descr]
1455         l = max([len(s) for s in token_descr])
1456         token_descr = [s + ["#"] * (l - len(s)) for s in token_descr]
1457         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
1458         return torch.tensor(id_descr, device=self.device)
1459
1460     # Make a list of strings from a tensor
1461     def tensor2str(self, x):
1462         def id2token(t):
1463             try:
1464                 return self.id2token[t.item()]
1465             except KeyError:
1466                 return "?"
1467
1468         return [" ".join([id2token(t) for t in r]) for r in x]
1469
1470     # Trim z to remove as many padding tokens as possible from the left
1471     # and right. If z is a tuple, all its elements are trimmed according
1472     # to the trimming computed for the first tensor.
1473     def trim(self, z, token="#"):
1474         n = self.token2id[token]
1475         if type(z) == tuple:
1476             x = z[0]
1477             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
1478             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
1479             return tuple([t[:, a:b] for t in z])
1480         else:
1481             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
1482             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
1483             return z[:, a:b]
1484
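    # A minimal illustration (not part of the original file): assuming
    # "#" maps to id 0, trimming torch.tensor([[0, 0, 7, 3, 0, 5, 0, 0]])
    # keeps the columns from the first to the last one that is non-"#"
    # in at least one row, giving tensor([[7, 3, 0, 5]]).
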
1485     ######################
1486
1487     def __init__(
1488         self,
1489         nb_train_samples,
1490         nb_test_samples,
1491         batch_size,
1492         size,
1493         nb_shapes,
1494         nb_colors,
1495         logger=None,
1496         device=torch.device("cpu"),
1497     ):
1498         super().__init__()
1499
1500         self.device = device
1501         self.batch_size = batch_size
1502         self.grid_factory = grid.GridFactory(
1503             size=size, nb_shapes=nb_shapes, nb_colors=nb_colors
1504         )
1505
1506         if logger is not None:
1507             logger(
1508                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
1509             )
1510
1511         self.train_descr = self.grid_factory.generate_samples(
1512             nb_train_samples, lambda r: tqdm.tqdm(r)
1513         )
1514         self.test_descr = self.grid_factory.generate_samples(
1515             nb_test_samples, lambda r: tqdm.tqdm(r)
1516         )
1517
1518         # Build the tokenizer
1519         tokens = set()
1520         for d in [self.train_descr, self.test_descr]:
1521             for s in d:
1522                 for t in s.strip().split(" "):
1523                     tokens.add(t)
1524         # turn this set into a sorted list so that the same descriptions
1525         # always produce the same token ids and tensors
1526         tokens = list(tokens)
1527         tokens.sort()
1528         tokens = ["#"] + tokens
1529         self.token2id = {t: n for n, t in enumerate(tokens)}
1530         self.id2token = {n: t for n, t in enumerate(tokens)}
1531         self.t_nul = self.token2id["#"]
1532         self.t_true = self.token2id["true"]
1533         self.t_false = self.token2id["false"]
1534
1535         # Tokenize the train and test sets
1536         self.train_input = self.str2tensor(self.train_descr)
1537         self.test_input = self.str2tensor(self.test_descr)
1538
1539     def batches(self, split="train", desc=None):
1540         assert split in {"train", "test"}
1541         input = self.train_input if split == "train" else self.test_input
1542         if desc is None:
1543             desc = f"epoch-{split}"
1544         for batch in tqdm.tqdm(
1545             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1546         ):
1547             yield self.trim(batch)
1548
1549     def vocabulary_size(self):
1550         return len(self.token2id)
1551
1552     def produce_results(
1553         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1554     ):
1555         correct = self.test_input[:1000]
1556         result = correct.clone()
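        # the positions to regenerate are exactly the true/false answer tokens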
1557         ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long()
1558         result *= 1 - ar_mask  # erase the answer tokens (paraaaaanoiaaaaaaa)
1559
1560         logger(f"----------------------------------------------------------")
1561
1562         for e in self.tensor2str(result[:10]):
1563             logger(f"test_before {e}")
1564
1565         masked_inplace_autoregression(
1566             model,
1567             self.batch_size,
1568             result,
1569             ar_mask,
1570             deterministic_synthesis,
1571             device=self.device,
1572         )
1573
1574         logger(f"----------------------------------------------------------")
1575
1576         for e in self.tensor2str(result[:10]):
1577             logger(f"test_after  {e}")
1578
1579         logger(f"----------------------------------------------------------")
1580
1581         nb_total = ar_mask.sum().item()
1582         nb_correct = ((correct == result).long() * ar_mask).sum().item()
1583
1584         logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
1585         logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
1586
1587
1588 ######################################################################
1589
1590 import qmlp
1591
1592
1593 class QMLP(Task):
1594     ######################
1595
1596     def __init__(
1597         self,
1598         nb_train_samples,
1599         nb_test_samples,
1600         batch_size,
1601         result_dir,
1602         logger=None,
1603         device=torch.device("cpu"),
1604     ):
1605         super().__init__()
1606
1607         self.device = device
1608         self.batch_size = batch_size
1609         self.nb_samples_per_mlp = 256
1610
1611         if logger is not None:
1612             logger(
1613                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
1614             )
1615
1616         seq, q_test_set, test_error = qmlp.generate_sequence_and_test_set(
1617             nb_mlps=nb_train_samples + nb_test_samples,
1618             nb_samples=self.nb_samples_per_mlp,
1619             device=self.device,
1620             batch_size=64,
1621             nb_epochs=250,
1622             nb_mlps_per_batch=1024,
1623         )
1624
1625         self.train_input = seq[:nb_train_samples]
1626         self.train_q_test_set = q_test_set[:nb_train_samples]
1627         self.train_ref_test_errors = test_error[:nb_train_samples]
1628         self.test_input = seq[nb_train_samples:]
1629         self.test_q_test_set = q_test_set[nb_train_samples:]
1630         self.test_ref_test_errors = test_error[nb_train_samples:]
1631
1632         filename = os.path.join(result_dir, f"train_errors_ref.dat")
1633         with open(filename, "w") as f:
1634             for e in self.train_ref_test_errors:
1635                 f.write(f"{e}\n")
1636
1637         filename = os.path.join(result_dir, f"test_errors_ref.dat")
1638         with open(filename, "w") as f:
1639             for e in self.test_ref_test_errors:
1640                 f.write(f"{e}\n")
1641
1642         self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
1643
1644     def batches(self, split="train", desc=None):
1645         assert split in {"train", "test"}
1646         input = self.train_input if split == "train" else self.test_input
1647         if desc is None:
1648             desc = f"epoch-{split}"
1649         for batch in tqdm.tqdm(
1650             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1651         ):
1652             yield batch
1653
1654     def vocabulary_size(self):
1655         return self.nb_codes
1656
1657     def produce_results(
1658         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1659     ):
1660         correct = self.test_input[:1000]
1661         result = correct.clone()
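        # keep the prefix encoding the training samples and regenerate only
        # the tail of the sequence, which holds the quantized MLP parameters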
1662         ar_mask = (
1663             torch.arange(result.size(1), device=result.device)
1664             > self.nb_samples_per_mlp * 3 + 1
1665         ).long()[None, :]
1666         ar_mask = ar_mask.expand_as(result)
1667         result *= 1 - ar_mask  # erase the tokens to be generated (paraaaaanoiaaaaaaa)
1668
1669         masked_inplace_autoregression(
1670             model,
1671             self.batch_size,
1672             result,
1673             ar_mask,
1674             deterministic_synthesis,
1675             device=self.device,
1676         )
1677
1678         q_train_set = result[:, : self.nb_samples_per_mlp * 3]
1679         q_params = result[:, self.nb_samples_per_mlp * 3 + 1 :]
1680         error_test = qmlp.evaluate_q_params(q_params, self.test_q_test_set)
1681
1682         filename = os.path.join(result_dir, f"test_errors_{n_epoch:04d}.dat")
1683         with open(filename, "w") as f:
1684             for e in error_test:
1685                 f.write(f"{e}\n")
1686
1687
1688 ######################################################################