1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, os, tqdm
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 from mygpt import BracketedSequence
16
17 # from graph import save_attention_image
18 save_attention_image = None
19
20 ######################################################################
21
22
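# Shared driver for in-place autoregressive generation: it splits `input`
# and `ar_mask` into mini-batches, temporarily puts the model in eval
# mode, and delegates to model.masked_inplace_autoregression so that the
# positions where ar_mask is 1 are regenerated in place while the others
# act as the prompt. A minimal usage sketch, mirroring how the tasks below
# call it (prompt_len is a hypothetical prompt length):
#
#     result = input.clone()
#     ar_mask = torch.zeros_like(result)
#     ar_mask[:, prompt_len:] = 1   # regenerate everything after the prompt
#     result *= 1 - ar_mask         # erase the part to be predicted
#     masked_inplace_autoregression(
#         model, batch_size, result, ar_mask, deterministic_synthesis=True
#     )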
23 def masked_inplace_autoregression(
24     model,
25     batch_size,
26     input,
27     ar_mask,
28     deterministic_synthesis,
29     forbidden_tokens=None,
30     progress_bar_desc="autoregression",
31     device=torch.device("cpu"),
32 ):
33     assert input.size() == ar_mask.size()
34
35     batches = zip(input.split(batch_size), ar_mask.split(batch_size))
36
37     if progress_bar_desc is not None:
38         batches = tqdm.tqdm(
39             batches,
40             dynamic_ncols=True,
41             desc=progress_bar_desc,
42             total=(input.size(0) + batch_size - 1) // batch_size,
43         )
44
45     with torch.autograd.no_grad():
46         t = model.training
47         model.eval()
48
49         for input, ar_mask in batches:
50             model.masked_inplace_autoregression(
51                 input, ar_mask, forbidden_tokens, deterministic_synthesis
52             )
53
54         model.train(t)
55
56
57 ######################################################################
58
59
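# Abstract interface shared by all the tasks below: batches() yields token
# tensors for the train or test split, vocabulary_size() returns the number
# of distinct token ids, and produce_results() runs the evaluation for one
# epoch and reports through the logger.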
60 class Task:
61     def batches(self, split="train"):
62         pass
63
64     def vocabulary_size(self):
65         pass
66
67     def produce_results(
68         self, n_epoch, model, result_dir, logger, deterministic_synthesis
69     ):
70         pass
71
72
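# Task built from a raw text file in which every sample takes two lines:
# the token sequence itself, then a mask string of the same length made of
# the characters 0, 1 and 2. Positions marked 1 or 2 are regenerated
# autoregressively at evaluation time, and only the positions marked 2 are
# counted in the reported accuracy. A hypothetical two-sample file:
#
#     abcde
#     00122
#     vwxyz
#     00120
#
# Sequences are right-padded with '#', which is reserved as the padding
# token.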
73 class TaskFromFile(Task):
74     def tensorize(self, pairs):
75         len_max = max([len(x[0]) for x in pairs])
76
77         input = torch.cat(
78             [
79                 torch.tensor(
80                     [
81                         [self.char2id[c] for c in s[0] + "#" * (len_max - len(s[0]))]
82                         for s in pairs
83                     ]
84                 )
85             ],
86             0,
87         ).to("cpu")
88
89         pred_mask = torch.cat(
90             [
91                 torch.tensor(
92                     [
93                         [int(c) for c in s[1] + "0" * (len_max - len(s[1]))]
94                         for s in pairs
95                     ]
96                 )
97             ],
98             0,
99         ).to("cpu")
100
101         return input, pred_mask
102
103     # Trim all the tensors in the tuple z, removing as many padding
104     # tokens as possible from the left and right of the first tensor. If
105     # z is a tuple, all its elements are trimmed according to the first.
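    # The min over rows flags the columns that contain only the padding
    # token; the cumulative sum then yields a, the index of the first
    # non-padding column, and b, one past the last one, so that z[:, a:b]
    # is the tightest window containing all the non-padding tokens.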
106     def trim(self, z, token="#"):
107         n = self.char2id[token]
108         if type(z) == tuple:
109             x = z[0]
110             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
111             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
112             return tuple([t[:, a:b] for t in z])
113         else:
114             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
115             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
116             return z[:, a:b]
117
118     def __init__(
119         self,
120         filename,
121         nb_train_samples,
122         nb_test_samples,
123         batch_size,
124         device=torch.device("cpu"),
125     ):
126         self.batch_size = batch_size
127         self.device = device
128
129         pairs = []
130         with open(filename, "r") as f:
131             for _ in range(nb_train_samples + nb_test_samples):
132                 sequence = f.readline().strip()
133                 pred_mask = f.readline().strip()
134                 assert len(sequence) == len(pred_mask)
135                 assert set(pred_mask).issubset({"0", "1", "2"}), f"{set(pred_mask)}"
136                 pairs.append((sequence, pred_mask))
137
138         symbols = ["#"] + sorted(set("".join([x[0] for x in pairs])) - set(["#"]))
139         self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
140         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
141
142         self.train_input, self.train_pred_masks = self.tensorize(
143             pairs[:nb_train_samples]
144         )
145         self.test_input, self.test_pred_masks = self.tensorize(pairs[nb_train_samples:])
146
147     def batches(self, split="train", nb_to_use=-1, desc=None):
148         assert split in {"train", "test"}
149         input = self.train_input if split == "train" else self.test_input
150         if nb_to_use > 0:
151             input = input[:nb_to_use]
152         if desc is None:
153             desc = f"epoch-{split}"
154         for batch in tqdm.tqdm(
155             input.split(self.batch_size), dynamic_ncols=True, desc=desc
156         ):
157             yield self.trim(batch).to(self.device)
158
159     def vocabulary_size(self):
160         return len(self.char2id)
161
162     def tensor2str(self, t):
163         return ["".join([self.id2char[x.item()] for x in s]) for s in t]
164
165     def produce_results(
166         self, n_epoch, model, result_dir, logger, deterministic_synthesis
167     ):
168         correct = self.trim(self.test_input[:1000]).to(self.device)
169         result = correct.clone()
170         pred_mask = self.test_pred_masks[:1000, : result.size(1)].to(self.device)
171         ar_mask = (pred_mask > 0).long()
172         result *= 1 - ar_mask  # paranoia: erase the tokens to be predicted
173
174         logger(f"----------------------------------------------------------")
175
176         for e in self.tensor2str(result[:10]):
177             logger(f"test_before {e}")
178
179         masked_inplace_autoregression(
180             model,
181             self.batch_size,
182             result,
183             ar_mask,
184             deterministic_synthesis,
185             device=self.device,
186         )
187
188         logger(f"----------------------------------------------------------")
189
190         for e, c in zip(self.tensor2str(result[:10]), self.tensor2str(correct[:10])):
191             logger(f"test_after  {e}")
192             logger(f"correct     {c}")
193
194         logger(f"----------------------------------------------------------")
195
196         err_mask = (pred_mask == 2).long()
197         nb_total = err_mask.sum().item()
198         nb_correct = ((correct == result).long() * err_mask).sum().item()
199
200         logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
201         logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
202
203
204 ####################
205
206 import problems
207
208
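# Generic wrapper around a problem object from problems.py, which must
# provide generate_sequences() returning (sequences, ar_mask), seq2str()
# for logging, and compute_nb_correct() for scoring.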
209 class SandBox(Task):
210     def __init__(
211         self,
212         problem,
213         nb_train_samples,
214         nb_test_samples,
215         batch_size,
216         logger=None,
217         device=torch.device("cpu"),
218         max_nb_codes=1024,
219     ):
220         super().__init__()
221
222         self.batch_size = batch_size
223         self.device = device
224         self.problem = problem
225
226         self.train_input, self.train_ar_mask = self.problem.generate_sequences(
227             nb_train_samples
228         )
229         self.test_input, self.test_ar_mask = self.problem.generate_sequences(
230             nb_test_samples
231         )
232
233         self.train_input, self.train_ar_mask = self.train_input.to(
234             device
235         ), self.train_ar_mask.to(device)
236         self.test_input, self.test_ar_mask = self.test_input.to(
237             device
238         ), self.test_ar_mask.to(device)
239
240         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
241
242         # A bit of paranoia never hurts
243         assert self.nb_codes <= max_nb_codes
244         assert self.train_input.min() >= 0
245         assert self.test_input.min() >= 0
246         assert tuple(x.item() for x in self.train_ar_mask.unique()) in {
247             (0,),
248             (1,),
249             (0, 1),
250         }
251         assert tuple(x.item() for x in self.test_ar_mask.unique()) in {
252             (0,),
253             (1,),
254             (0, 1),
255         }
256
257         if logger is not None:
258             for s, a in zip(self.train_input[:100], self.train_ar_mask[:100]):
259                 logger(f"train_sequences {self.problem.seq2str(s)}")
260                 a = "".join(["01"[x.item()] for x in a])
261                 logger(f"                {a}")
262
263     def batches(self, split="train", nb_to_use=-1, desc=None):
264         assert split in {"train", "test"}
265         input = self.train_input if split == "train" else self.test_input
266         if nb_to_use > 0:
267             input = input[:nb_to_use]
268         if desc is None:
269             desc = f"epoch-{split}"
270         for batch in tqdm.tqdm(
271             input.split(self.batch_size), dynamic_ncols=True, desc=desc
272         ):
273             yield batch
274
275     def vocabulary_size(self):
276         return self.nb_codes
277
278     def produce_results(
279         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
280     ):
281         def compute_accuracy(input, ar_mask, logger=None):
282             input, ar_mask = input[:nmax], ar_mask[:nmax]
283             result = input.clone() * (1 - ar_mask)
284
285             masked_inplace_autoregression(
286                 model,
287                 self.batch_size,
288                 result,
289                 ar_mask,
290                 deterministic_synthesis,
291                 progress_bar_desc=None,
292                 device=self.device,
293             )
294
295             log_ground_truth = ar_mask.min() == 0
296
297             if logger is not None:
298                 for sp, st in zip(result[:10], input[:10]):
299                     logger(
300                         f"test_sequences {n_epoch} prediction   {self.problem.seq2str(sp)}"
301                     )
302                     if log_ground_truth:
303                         logger(
304                             f"               {n_epoch} ground truth {self.problem.seq2str(st)}"
305                         )
306
307             nb_total, nb_correct = self.problem.compute_nb_correct(
308                 input, ar_mask, result
309             )
310
311             # nb_total = ar_mask.sum().item()
312             # nb_correct = ((result == input).long() * ar_mask).sum().item()
313
314             return nb_total, nb_correct
315
316         train_nb_total, train_nb_correct = compute_accuracy(
317             self.train_input, self.train_ar_mask
318         )
319
320         logger(
321             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
322         )
323
324         test_nb_total, test_nb_correct = compute_accuracy(
325             self.test_input, self.test_ar_mask, logger
326         )
327
328         logger(
329             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
330         )
331
332         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
333
334         if save_attention_image is not None:
335             for k in range(10):
336                 ns = torch.randint(self.test_input.size(0), (1,)).item()
337                 input = self.test_input[ns : ns + 1].clone()
338
339                 with torch.autograd.no_grad():
340                     t = model.training
341                     model.eval()
342                     # model.record_attention(True)
343                     model(BracketedSequence(input))
344                     model.train(t)
345                     # ram = model.retrieve_attention()
346                     # model.record_attention(False)
347
348                 # tokens_output = [c for c in self.problem.seq2str(input[0])]
349                 # tokens_input = ["n/a"] + tokens_output[:-1]
350                 # for n_head in range(ram[0].size(1)):
351                 # filename = os.path.join(
352                 # result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
353                 # )
354                 # attention_matrices = [m[0, n_head] for m in ram]
355                 # save_attention_image(
356                 # filename,
357                 # tokens_input,
358                 # tokens_output,
359                 # attention_matrices,
360                 # k_top=10,
361                 ##min_total_attention=0.9,
362                 # token_gap=12,
363                 # layer_gap=50,
364                 # )
365                 # logger(f"wrote {filename}")
366
367
368 ######################################################################
369
370 import picoclvr
371
372
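# A PicoCLVR sample is a space-separated token sequence: a textual
# description of the scene properties (separated by <sep>), the <img>
# token, then the image itself encoded as grid tokens. For the
# missing-properties evaluation everything from <img> onward is blanked to
# <nul> and regenerated, and picoclvr.nb_properties() counts how many of
# the requested properties are missing from the generated image.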
373 class PicoCLVR(Task):
374     # Make a tensor from a list of strings
375     def tensorize(self, descr):
376         token_descr = [s.strip().split(" ") for s in descr]
377         l = max([len(s) for s in token_descr])
378         token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
379         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
380         return torch.tensor(id_descr, device=self.device)
381
382     # Make a list of strings from a tensor
383     def detensorize(self, x):
384         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
385
386     # Trim all the tensors in the tuple z, removing as many padding
387     # tokens as possible from the left and right of the first tensor. If
388     # z is a tuple, all its elements are trimmed according to the first.
389     def trim(self, z, token="<nul>"):
390         n = self.token2id[token]
391         if type(z) == tuple:
392             x = z[0]
393             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
394             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
395             return tuple([t[:, a:b] for t in z])
396         else:
397             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
398             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
399             return z[:, a:b]
400
401     ######################
402
403     def __init__(
404         self,
405         nb_train_samples,
406         nb_test_samples,
407         batch_size,
408         height,
409         width,
410         nb_colors=5,
411         logger=None,
412         device=torch.device("cpu"),
413         pruner_train=None,
414         pruner_eval=None,
415     ):
416         super().__init__()
417
418         def generate_descr(nb, cache_suffix, pruner):
419             return picoclvr.generate(
420                 nb,
421                 height=self.height,
422                 width=self.width,
423                 nb_colors=nb_colors,
424                 pruner=pruner,
425             )
426
427         self.height = height
428         self.width = width
429         self.batch_size = batch_size
430         self.device = device
431         self.pruner_train = pruner_train
432         self.pruner_eval = pruner_eval
433
434         if logger is not None:
435             logger(
436                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
437             )
438
439         self.train_descr = generate_descr(
440             nb_train_samples, "train", pruner=self.pruner_train
441         )
442         self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)
443
444         # Build the tokenizer
445         tokens = {"<nul>", "<img>"}
446         for d in [self.train_descr, self.test_descr]:
447             for s in d:
448                 for t in s.strip().split(" "):
449                     tokens.add(t)
450         # Turn the set into a sorted list so that the same descriptions
451         # always produce the same token ids and tensors
452         tokens = list(tokens)
453         tokens.sort()
454         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
455         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
456         self.t_img, self.t_nul = self.token2id["<img>"], self.token2id["<nul>"]
457
458         # Tokenize the train and test sets
459         self.train_input = self.tensorize(self.train_descr)
460         self.test_input = self.tensorize(self.test_descr)
461
462     def batches(self, split="train"):
463         assert split in {"train", "test"}
464         input = self.train_input if split == "train" else self.test_input
465         for batch in tqdm.tqdm(
466             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
467         ):
468             yield self.trim(batch)
469
470     def vocabulary_size(self):
471         return len(self.token2id)
472
473     def compute_missing_properties(
474         self, n_epoch, model, logger, deterministic_synthesis, pruner=None
475     ):
476         acc_nb_requested_properties = []
477         acc_nb_missing_properties = []
478         acc_nb_results = 0
479
480         for input in tqdm.tqdm(
481             self.test_input.split(self.batch_size),
482             dynamic_ncols=True,
483             desc=f"test-properties",
484         ):
485             result = input.clone()
486             ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1)
487             result = (1 - ar_mask) * result + ar_mask * self.t_nul
488             masked_inplace_autoregression(
489                 model,
490                 self.batch_size,
491                 result,
492                 ar_mask,
493                 deterministic_synthesis,
494                 progress_bar_desc=None,
495                 device=self.device,
496             )
497
498             result_descr = self.detensorize(result)
499             np = picoclvr.nb_properties(
500                 result_descr,
501                 height=self.height,
502                 width=self.width,
503                 pruner=pruner,
504             )
505             nb_requested_properties, _, nb_missing_properties = zip(*np)
506             acc_nb_requested_properties += nb_requested_properties
507             acc_nb_missing_properties += nb_missing_properties
508             acc_nb_results += len(result_descr)
509
510         nb_requested_properties = sum(acc_nb_requested_properties)
511         nb_missing_properties = sum(acc_nb_missing_properties)
512
513         prefix = "" if pruner is None else "pruned_"
514         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
515         logger(
516             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
517         )
518         logger(
519             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
520         )
521
522         logger(
523             f"main_test_accuracy {n_epoch} {1-nb_missing_properties/nb_requested_properties}"
524         )
525
526     ######################################################################
527
528     def produce_results(
529         self, n_epoch, model, result_dir, logger, deterministic_synthesis
530     ):
531         self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis)
532
533         if self.pruner_eval is not None:
534             self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis, self.pruner_eval)
535
536         nb_tokens_to_generate = self.height * self.width + 3
537         result_descr = []
538         nb_per_primer = 8
539         primer = []
540
541         for primer_descr in [
542             "red above green <sep> green top <sep> blue right of red",
543             "there is red <sep> there is yellow <sep> there is blue",
544             "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
545             "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
546         ]:
547             primer += [primer_descr + " <img>"] * nb_per_primer
548
549         result = self.tensorize(primer)
550         fill = result.new_full(
551             result.size()[:-1] + (self.height * self.width + 1,), self.t_nul
552         )
553         result = torch.cat((result, fill), 1)
554         ar_mask = (result == self.t_nul).long()
555         masked_inplace_autoregression(
556             model,
557             self.batch_size,
558             result,
559             ar_mask,
560             deterministic_synthesis,
561             device=self.device,
562         )
563         result_descr = self.detensorize(result)
564
565         np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
566
567         acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
568         acc_nb_results = len(result_descr)
569
570         nb_requested_properties = sum(acc_nb_requested_properties)
571         nb_missing_properties = sum(acc_nb_missing_properties)
572
573         prefix = "demo_"
574         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
575         logger(
576             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
577         )
578         logger(
579             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
580         )
581
582         img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)
583
584         if img.dim() == 5:
585             if img.size(1) == 1:
586                 img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
587             else:
588                 img = torch.cat(
589                     [
590                         torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
591                         for x in img
592                     ],
593                     0,
594                 )
595
596         image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png")
597         torchvision.utils.save_image(
598             img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
599         )
600         logger(f"wrote {image_name}")
601
602
603 ######################################################################
604
605
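# MNIST images are flattened into sequences of 28*28 pixel values used
# directly as tokens, hence the fixed vocabulary size of 256.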
606 class MNIST(Task):
607     def __init__(
608         self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
609     ):
610         super().__init__()
611
612         self.nb_train_samples = nb_train_samples
613         self.nb_test_samples = nb_test_samples
614         self.batch_size = batch_size
615         self.device = device
616         data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True)
617         self.train_input = data_set.data[:nb_train_samples].view(-1, 28 * 28).long()
618         data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
619         self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
620
621     def batches(self, split="train", nb_to_use=-1, desc=None):
622         assert split in {"train", "test"}
623         input = self.train_input if split == "train" else self.test_input
624         if nb_to_use > 0:
625             input = input[:nb_to_use]
626         if desc is None:
627             desc = f"epoch-{split}"
628         for batch in tqdm.tqdm(
629             input.split(self.batch_size), dynamic_ncols=True, desc=desc
630         ):
631             yield batch
632
633     def vocabulary_size(self):
634         return 256
635
636     def produce_results(
637         self, n_epoch, model, result_dir, logger, deterministic_synthesis
638     ):
639         results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
640         ar_mask = torch.full_like(results, 1)
641         masked_inplace_autoregression(
642             model,
643             self.batch_size,
644             results,
645             ar_mask,
646             deterministic_synthesis,
647             device=self.device,
648         )
649         image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
650         torchvision.utils.save_image(
651             1 - results.reshape(-1, 1, 28, 28) / 255.0,
652             image_name,
653             nrow=16,
654             pad_value=0.8,
655         )
656         logger(f"wrote {image_name}")
657
658
659 ######################################################################
660
661 import maze
662
663
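# A maze sample is the concatenation of two flattened height x width maps:
# the maze itself followed by the target path. At evaluation time the path
# half is regenerated, maze.path_correctness() checks that the predicted
# path is valid, and the count matrix tallies optimal vs. predicted path
# lengths over the correct predictions.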
664 class Maze(Task):
665     def map2seq(self, *m):
666         return torch.cat([x.flatten(1) for x in m], 1)
667
668     def seq2map(self, s):
669         s = s.reshape(s.size(0), -1, self.height, self.width)
670         return (s[:, k] for k in range(s.size(1)))
671
672     def __init__(
673         self,
674         nb_train_samples,
675         nb_test_samples,
676         batch_size,
677         height,
678         width,
679         nb_walls,
680         device=torch.device("cpu"),
681     ):
682         super().__init__()
683
684         self.batch_size = batch_size
685         self.height = height
686         self.width = width
687         self.device = device
688
689         train_mazes, train_paths, _ = maze.create_maze_data(
690             nb_train_samples,
691             height=height,
692             width=width,
693             nb_walls=nb_walls,
694             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
695         )
696         self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
697
698         test_mazes, test_paths, _ = maze.create_maze_data(
699             nb_test_samples,
700             height=height,
701             width=width,
702             nb_walls=nb_walls,
703             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
704         )
705         self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
706
707         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
708
709     def batches(self, split="train", nb_to_use=-1, desc=None):
710         assert split in {"train", "test"}
711         input = self.train_input if split == "train" else self.test_input
712         if nb_to_use > 0:
713             input = input[:nb_to_use]
714         if desc is None:
715             desc = f"epoch-{split}"
716         for batch in tqdm.tqdm(
717             input.split(self.batch_size), dynamic_ncols=True, desc=desc
718         ):
719             yield batch
720
721     def vocabulary_size(self):
722         return self.nb_codes
723
724     def compute_error(
725         self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
726     ):
727         nb_total, nb_correct = 0, 0
728         count = torch.zeros(
729             self.width * self.height,
730             self.width * self.height,
731             device=self.device,
732             dtype=torch.int64,
733         )
734
735         for input in self.batches(split, nb_to_use):
736             result = input.clone()
737             ar_mask = result.new_zeros(result.size())
738             ar_mask[:, self.height * self.width :] = 1
739             result *= 1 - ar_mask
740             masked_inplace_autoregression(
741                 model,
742                 self.batch_size,
743                 result,
744                 ar_mask,
745                 deterministic_synthesis,
746                 progress_bar_desc=None,
747                 device=self.device,
748             )
749             mazes, paths = self.seq2map(result)
750             path_correctness = maze.path_correctness(mazes, paths)
751             nb_correct += path_correctness.long().sum()
752             nb_total += mazes.size(0)
753
754             optimal_path_lengths = (
755                 (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
756             )
757             predicted_path_lengths = (
758                 (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
759             )
760             optimal_path_lengths = optimal_path_lengths[path_correctness]
761             predicted_path_lengths = predicted_path_lengths[path_correctness]
762             count[optimal_path_lengths, predicted_path_lengths] += 1
763
764         if count.max() == 0:
765             count = None
766         else:
767             count = count[
768                 : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
769             ]
770
771         return nb_total, nb_correct, count
772
773     def produce_results(
774         self, n_epoch, model, result_dir, logger, deterministic_synthesis
775     ):
776         train_nb_total, train_nb_correct, count = self.compute_error(
777             model,
778             "train",
779             nb_to_use=1000,
780             deterministic_synthesis=deterministic_synthesis,
781         )
782         logger(
783             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
784         )
785
786         test_nb_total, test_nb_correct, count = self.compute_error(
787             model,
788             "test",
789             nb_to_use=1000,
790             deterministic_synthesis=deterministic_synthesis,
791         )
792         logger(
793             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
794         )
795
796         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
797
798         if count is not None:
799             proportion_optimal = count.diagonal().sum().float() / count.sum()
800             logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
801             with open(
802                 os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
803             ) as f:
804                 for i in range(count.size(0)):
805                     for j in range(count.size(1)):
806                         eol = " " if j < count.size(1) - 1 else "\n"
807                         f.write(f"{count[i,j]}{eol}")
808
809         input = self.test_input[:48]
810         result = input.clone()
811         ar_mask = result.new_zeros(result.size())
812         ar_mask[:, self.height * self.width :] = 1
813         result *= 1 - ar_mask
814         masked_inplace_autoregression(
815             model,
816             self.batch_size,
817             result,
818             ar_mask,
819             deterministic_synthesis,
820             device=self.device,
821         )
822
823         mazes, paths = self.seq2map(input)
824         _, predicted_paths = self.seq2map(result)
825
826         filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png")
827         maze.save_image(
828             filename,
829             mazes=mazes,
830             target_paths=paths,
831             predicted_paths=predicted_paths,
832             path_correct=maze.path_correctness(mazes, predicted_paths),
833             path_optimal=maze.path_optimality(paths, predicted_paths),
834         )
835         logger(f"wrote {filename}")
836
837
838 ######################################################################
839
840
841 import snake
842
843
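# For the Snake task, only the even positions located after the prompt
# (the first 2 * prompt_length tokens) are regenerated at evaluation time,
# and accuracy is measured only on positions whose cell had already been
# visited (prior_visits > 0).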
844 class Snake(Task):
845     def __init__(
846         self,
847         nb_train_samples,
848         nb_test_samples,
849         batch_size,
850         height,
851         width,
852         nb_colors,
853         length,
854         prompt_length,
855         device=torch.device("cpu"),
856     ):
857         super().__init__()
858
859         self.batch_size = batch_size
860         self.height = height
861         self.width = width
862         self.device = device
863         self.prompt_length = prompt_length
864
865         self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
866             nb_train_samples,
867             height,
868             width,
869             nb_colors,
870             length,
871             prompt_length,
872             self.device,
873         )
874         self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
875             nb_test_samples,
876             height,
877             width,
878             nb_colors,
879             length,
880             prompt_length,
881             self.device,
882         )
883
884         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
885
886     def batches(self, split="train", nb_to_use=-1, desc=None):
887         assert split in {"train", "test"}
888         input = self.train_input if split == "train" else self.test_input
889         if nb_to_use > 0:
890             input = input[:nb_to_use]
891         if desc is None:
892             desc = f"epoch-{split}"
893         for batch in tqdm.tqdm(
894             input.split(self.batch_size), dynamic_ncols=True, desc=desc
895         ):
896             yield batch
897
898     def vocabulary_size(self):
899         return self.nb_codes
900
901     def produce_results(
902         self, n_epoch, model, result_dir, logger, deterministic_synthesis
903     ):
904         def compute_nb_correct(input, prior_visits):
905             result = input.clone()
906             i = torch.arange(result.size(1), device=result.device)[None, :]
907             ar_mask = (
908                 torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
909                 .long()
910                 .expand_as(result)
911             )
912             result *= 1 - ar_mask
913
914             masked_inplace_autoregression(
915                 model,
916                 self.batch_size,
917                 result,
918                 ar_mask,
919                 deterministic_synthesis,
920                 device=self.device,
921             )
922
923             nb_total = ((prior_visits > 0) * ar_mask).sum()
924
925             nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum()
926
927             return nb_total, nb_correct
928
929         test_nb_total, test_nb_correct = compute_nb_correct(
930             self.test_input[:1000], self.test_prior_visits[:1000]
931         )
932
933         logger(
934             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
935         )
936
937         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
938
939
940 ######################################################################
941
942
943 import stack
944
945
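# Stack sequences record operations on nb_stacks stacks of values with
# nb_digits digits, generated by stack.generate_sequences(). For
# evaluation, stack.remove_popped_values() blanks the values returned by
# the pop operations and the model has to regenerate them; a popped value
# counts as correct only if all of its digits are right.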
946 class Stack(Task):
947     def __init__(
948         self,
949         nb_train_samples,
950         nb_test_samples,
951         batch_size,
952         logger,
953         nb_steps,
954         nb_stacks,
955         nb_digits,
956         fraction_values_for_train=None,
957         device=torch.device("cpu"),
958     ):
959         super().__init__()
960
961         self.batch_size = batch_size
962         self.nb_steps = nb_steps
963         self.nb_stacks = nb_stacks
964         self.nb_digits = nb_digits
965         self.device = device
966
967         if fraction_values_for_train is None:
968             values_for_train = None
969             values_for_test = None
970         else:
971             all = torch.randperm(10**nb_digits)
972             nb_for_train = int(all.size(0) * fraction_values_for_train)
973             values_for_train = all[:nb_for_train]
974             values_for_test = all[nb_for_train:]
975
976         self.train_input, self.train_stack_counts = stack.generate_sequences(
977             nb_train_samples,
978             nb_steps,
979             nb_stacks,
980             nb_digits,
981             values_for_train,
982             self.device,
983         )
984
985         self.test_input, self.test_stack_counts = stack.generate_sequences(
986             nb_test_samples,
987             nb_steps,
988             nb_stacks,
989             nb_digits,
990             values_for_test,
991             self.device,
992         )
993
994         i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
995         counts = self.test_stack_counts.flatten()[i.flatten()]
996         counts = F.one_hot(counts).sum(0)
997         logger(f"test_pop_stack_counts {counts}")
998
999         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1000
1001     def batches(self, split="train", nb_to_use=-1, desc=None):
1002         assert split in {"train", "test"}
1003         input = self.train_input if split == "train" else self.test_input
1004         if nb_to_use > 0:
1005             input = input[:nb_to_use]
1006         if desc is None:
1007             desc = f"epoch-{split}"
1008         for batch in tqdm.tqdm(
1009             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1010         ):
1011             yield batch
1012
1013     def vocabulary_size(self):
1014         return self.nb_codes
1015
1016     def produce_results(
1017         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1018     ):
1019         def compute_nb_correct(input):
1020             result = input.clone()
1021             stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
1022             ar_mask = (result != input).long()
1023             masked_inplace_autoregression(
1024                 model,
1025                 self.batch_size,
1026                 result,
1027                 ar_mask,
1028                 deterministic_synthesis,
1029                 device=self.device,
1030             )
1031
1032             errors = ((result != input).long() * ar_mask).reshape(
1033                 -1, 1 + self.nb_digits
1034             )
1035             ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)
1036
1037             nb_total = ar_mask.max(1).values.sum()
1038             nb_correct = nb_total - errors.max(1).values.sum()
1039
1040             return nb_total, nb_correct
1041
1042         test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
1043
1044         logger(
1045             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1046         )
1047
1048         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
1049
1050         ##############################################################
1051         # Log a few generated sequences
1052         input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
1053         result = input.clone()
1054         stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
1055         ar_mask = (result != input).long()
1056
1057         # for n in range(result.size(0)):
1058         # logger(
1059         # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
1060         # )
1061
1062         masked_inplace_autoregression(
1063             model,
1064             self.batch_size,
1065             result,
1066             ar_mask,
1067             deterministic_synthesis,
1068             device=self.device,
1069         )
1070
1071         for n in range(result.size(0)):
1072             logger(
1073                 f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
1074             )
1075         ##############################################################
1076
1077
1078 ######################################################################
1079
1080 import rpl
1081
1082
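# An RPL sample is made of <in> ... <out> ... segments showing the
# input/output stacks of several runs, followed by the <prg> marker, the
# program, and <end>. Two evaluations are run: regenerating the program
# given the runs (compute_nb_errors_prog, skipped when no_prog is set,
# in which case the program is excised from the sequences) and
# regenerating the tokens between the last <out> marker and <prg>
# (compute_nb_errors_output).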
1083 class RPL(Task):
1084     def tensorize(self, sequences):
1085         len_max = max([len(x) for x in sequences])
1086         return torch.cat(
1087             [
1088                 torch.tensor(
1089                     [
1090                         [
1091                             self.token2id[str(c)]
1092                             for c in s + ["<nul>"] * (len_max - len(s))
1093                         ]
1094                         for s in sequences
1095                     ]
1096                 )
1097             ],
1098             0,
1099         )
1100
1101     def seq2str(self, seq):
1102         return " ".join([self.id2token[i] for i in seq])
1103
1104     def __init__(
1105         self,
1106         nb_train_samples,
1107         nb_test_samples,
1108         batch_size,
1109         nb_starting_values=3,
1110         max_input=9,
1111         prog_len=6,
1112         nb_runs=5,
1113         no_prog=False,
1114         logger=None,
1115         device=torch.device("cpu"),
1116     ):
1117         super().__init__()
1118
1119         self.batch_size = batch_size
1120         self.device = device
1121         self.no_prog = no_prog
1122
1123         train_sequences = [
1124             rpl.generate(
1125                 nb_starting_values=nb_starting_values,
1126                 nb_result_values_max=4 * nb_starting_values,
1127                 max_input=max_input,
1128                 prog_len=prog_len,
1129                 nb_runs=nb_runs,
1130             )
1131             for _ in tqdm.tqdm(range(nb_train_samples), desc="train-data")
1132         ]
1133
1134         test_sequences = [
1135             rpl.generate(
1136                 nb_starting_values=nb_starting_values,
1137                 nb_result_values_max=4 * nb_starting_values,
1138                 max_input=max_input,
1139                 prog_len=prog_len,
1140                 nb_runs=nb_runs,
1141             )
1142             for _ in tqdm.tqdm(range(nb_test_samples), desc="test-data")
1143         ]
1144
1145         symbols = list(
1146             set(["<nul>"] + [x for l in train_sequences + test_sequences for x in l])
1147         )
1148         val_max = max([x if type(x) is int else 0 for x in symbols])
1149         symbols = list(filter(lambda x: type(x) is str, symbols))
1150         symbols.sort()
1151         symbols += [str(n) for n in range(val_max + 1)]
1152         self.token2id = dict([(c, n) for n, c in enumerate(symbols)])
1153         self.id2token = dict([(n, c) for c, n in self.token2id.items()])
1154
1155         self.t_nul = self.token2id["<nul>"]
1156         self.t_input = self.token2id["<in>"]
1157         self.t_output = self.token2id["<out>"]
1158         self.t_prog = self.token2id["<prg>"]
1159         self.t_end = self.token2id["<end>"]
1160
1161         self.train_input = self.tensorize(train_sequences)
1162         self.test_input = self.tensorize(test_sequences)
1163
1164         if no_prog:
1165             # Excise the program from every train and test example
1166             k = torch.arange(self.train_input.size(1), device=self.train_input.device)[
1167                 None, :
1168             ]
1169             p = (
1170                 ((self.train_input == self.t_prog).long() * k)
1171                 .max(1, keepdim=True)
1172                 .values
1173             )
1174             self.train_input = (
1175                 self.train_input * (k <= p).long()
1176                 + self.t_end * (k == p + 1).long()
1177                 + self.t_nul * (k > p + 1).long()
1178             )
1179             k = torch.arange(self.test_input.size(1), device=self.test_input.device)[
1180                 None, :
1181             ]
1182             p = (
1183                 ((self.test_input == self.t_prog).long() * k)
1184                 .max(1, keepdim=True)
1185                 .values
1186             )
1187             self.test_input = (
1188                 self.test_input * (k <= p).long()
1189                 + self.t_end * (k == p + 1).long()
1190                 + self.t_nul * (k > p + 1).long()
1191             )
1192
1193         if logger is not None:
1194             logger(f"value_max {val_max}")
1195             for x in self.train_input[:25]:
1196                 end = (x != self.t_nul).nonzero().max().item() + 1
1197                 seq = [self.id2token[i.item()] for i in x[:end]]
1198                 s = " ".join(seq)
1199                 logger(f"example_seq {s}")
1200
1201         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1202
1203     def batches(self, split="train", nb_to_use=-1, desc=None):
1204         assert split in {"train", "test"}
1205         input = self.train_input if split == "train" else self.test_input
1206         if nb_to_use > 0:
1207             input = input[:nb_to_use]
1208         if desc is None:
1209             desc = f"epoch-{split}"
1210         for batch in tqdm.tqdm(
1211             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1212         ):
1213             last = (batch != self.t_nul).max(0).values.nonzero().max() + 3
1214             batch = batch[:, :last].to(self.device)
1215             yield batch
1216
1217     def vocabulary_size(self):
1218         return self.nb_codes
1219
1220     def produce_results(
1221         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1222     ):
1223         # --------------------------------------------------------------------
1224         def compute_nb_errors_prog(input, nb_to_log=0):
1225             result = input.clone()
1226             s = (result == self.t_prog).long()
1227             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1228             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1229
1230             masked_inplace_autoregression(
1231                 model,
1232                 self.batch_size,
1233                 result,
1234                 ar_mask,
1235                 deterministic_synthesis,
1236                 device=self.device,
1237             )
1238
1239             sum_nb_total, sum_nb_errors = 0, 0
1240             for one_input, one_result in zip(input, result):
1241                 seq = [self.id2token[i.item()] for i in one_result]
1242                 nb_total, nb_errors, prog, stacks = rpl.compute_nb_errors(seq)
1243                 sum_nb_total += 1
1244                 sum_nb_errors += 0 if nb_errors == 0 else 1
1245                 if nb_to_log > 0:
1246                     gt_seq = [self.id2token[i.item()] for i in one_input]
1247                     _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq)
1248                     gt_prog = " ".join([str(x) for x in gt_prog])
1249                     prog = " ".join([str(x) for x in prog])
1250                     comment = "*" if nb_errors == 0 else "-"
1251                     logger(f"{comment} PROG [{gt_prog}] PREDICTED [{prog}]")
1252                     for start_stack, target_stack, result_stack, correct in stacks:
1253                         comment = "*" if correct else "-"
1254                         start_stack = " ".join([str(x) for x in start_stack])
1255                         target_stack = " ".join([str(x) for x in target_stack])
1256                         result_stack = " ".join([str(x) for x in result_stack])
1257                         logger(
1258                             f"  {comment} [{start_stack}] -> [{target_stack}] PREDICTED [{result_stack}]"
1259                         )
1260                     nb_to_log -= 1
1261
1262             return sum_nb_total, sum_nb_errors
1263
1264         # --------------------------------------------------------------------
1265         def compute_nb_errors_output(input, nb_to_log=0):
1266             result = input.clone()
1267             k = torch.arange(result.size(1), device=result.device)[None, :]
1268             last_output_idx = (
1269                 ((result == self.t_output) * k).max(dim=1, keepdim=True).values
1270             )
1271             first_prog_idx = (
1272                 ((result == self.t_prog) * k).max(dim=1, keepdim=True).values
1273             )
1274             ar_mask = (k > last_output_idx).long() * (k < first_prog_idx).long()
1275             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1276
1277             masked_inplace_autoregression(
1278                 model,
1279                 self.batch_size,
1280                 result,
1281                 ar_mask,
1282                 deterministic_synthesis,
1283                 device=self.device,
1284             )
1285
1286             sum_nb_total, sum_nb_errors = 0, 0
1287             for one_input, one_result, i, j in zip(
1288                 input, result, last_output_idx, first_prog_idx
1289             ):
1290                 seq = [self.id2token[i.item()] for i in one_result]
1291                 sum_nb_total += 1
1292                 correct = (one_input - one_result).abs().max() == 0
1293                 sum_nb_errors += 0 if correct else 1
1294                 if nb_to_log > 0:
1295                     result_stack = [
1296                         self.id2token[u.item()] for u in one_result[i : j + 1]
1297                     ]
1298                     target_stack = [
1299                         self.id2token[u.item()] for u in one_input[i : j + 1]
1300                     ]
1301                     comment = "*" if correct else "-"
1302                     result_stack = " ".join([str(x) for x in result_stack])
1303                     target_stack = " ".join([str(x) for x in target_stack])
1304                     logger(
1305                         f"output_test {comment} [{target_stack}] PREDICTED [{result_stack}]"
1306                     )
1307                     nb_to_log -= 1
1308
1309             return sum_nb_total, sum_nb_errors
1310
1311         # --------------------------------------------------------------------
1312
1313         if not self.no_prog:
1314             test_nb_total, test_nb_errors = compute_nb_errors_prog(
1315                 self.test_input[:1000].to(self.device), nb_to_log=10
1316             )
1317
1318             logger(
1319                 f"accuracy_prog_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1320             )
1321
1322             logger(f"main_test_accuracy {n_epoch} {1-test_nb_errors/test_nb_total}")
1323
1324         test_nb_total, test_nb_errors = compute_nb_errors_output(
1325             self.test_input[:1000].to(self.device), nb_to_log=10
1326         )
1327
1328         logger(
1329             f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1330         )
1331
1332         if save_attention_image is None:
1333             logger("no save_attention_image (is pycairo installed?)")
1334         else:
1335             ns = torch.randint(self.test_input.size(0), (1,)).item()
1336             input = self.test_input[ns : ns + 1].clone()
1337             last = (input != self.t_nul).max(0).values.nonzero().max() + 3
1338             input = input[:, :last].to(self.device)
1339
1340             with torch.autograd.no_grad():
1341                 t = model.training
1342                 model.eval()
1343                 model.record_attention(True)
1344                 model(BracketedSequence(input))
1345                 model.train(t)
1346                 ram = model.retrieve_attention()
1347                 model.record_attention(False)
1348
1349             tokens_output = [self.id2token[i.item()] for i in input[0]]
1350             tokens_input = ["n/a"] + tokens_output[:-1]
1351             for n_head in range(ram[0].size(1)):
1352                 filename = os.path.join(
1353                     result_dir, f"rpl_attention_{n_epoch}_h{n_head}.pdf"
1354                 )
1355                 attention_matrices = [m[0, n_head] for m in ram]
1356                 save_attention_image(
1357                     filename,
1358                     tokens_input,
1359                     tokens_output,
1360                     attention_matrices,
1361                     k_top=10,
1362                     # min_total_attention=0.9,
1363                     token_gap=12,
1364                     layer_gap=50,
1365                 )
1366                 logger(f"wrote {filename}")
1367
1368
1369 ######################################################################
1370
1371
1372 import expr
1373
1374
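# Expr works at the character level on sequences produced by
# expr.generate_sequences(). Everything after the first space character is
# masked out and regenerated at evaluation time; a sequence counts as
# correct only if it is reproduced exactly, and expr.extract_results() is
# used to tabulate, per variable, how far each predicted value is from the
# true one.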
1375 class Expr(Task):
1376     def tensorize(self, sequences):
1377         len_max = max([len(x) for x in sequences])
1378         return torch.cat(
1379             [
1380                 torch.tensor(
1381                     [
1382                         [self.char2id[c] for c in s + "#" * (len_max - len(s))]
1383                         for s in sequences
1384                     ]
1385                 )
1386             ],
1387             0,
1388         ).to(self.device)
1389
1390     def __init__(
1391         self,
1392         nb_train_samples,
1393         nb_test_samples,
1394         nb_variables,
1395         sequence_length,
1396         operand_max,
1397         result_max,
1398         batch_size,
1399         device=torch.device("cpu"),
1400     ):
1401         super().__init__()
1402
1403         self.batch_size = batch_size
1404         self.device = device
1405
1406         train_sequences = expr.generate_sequences(
1407             nb_train_samples,
1408             nb_variables=nb_variables,
1409             length=sequence_length,
1410             operand_max=operand_max,
1411             result_max=result_max,
1412         )
1413
1414         test_sequences = expr.generate_sequences(
1415             nb_test_samples,
1416             nb_variables=nb_variables,
1417             length=sequence_length,
1418             operand_max=operand_max,
1419             result_max=result_max,
1420         )
1421
1422         symbols = list(set("#" + "".join(train_sequences + test_sequences)))
1423         symbols.sort()
1424
1425         self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
1426         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
1427
1428         self.filler, self.space = self.char2id["#"], self.char2id[" "]
1429
1430         self.train_input = self.tensorize(train_sequences)
1431         self.test_input = self.tensorize(test_sequences)
1432
1433         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1434
1435     def batches(self, split="train", nb_to_use=-1, desc=None):
1436         assert split in {"train", "test"}
1437         input = self.train_input if split == "train" else self.test_input
1438         if nb_to_use > 0:
1439             input = input[:nb_to_use]
1440         if desc is None:
1441             desc = f"epoch-{split}"
1442         for batch in tqdm.tqdm(
1443             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1444         ):
1445             last = (batch != self.filler).max(0).values.nonzero().max() + 3
1446             batch = batch[:, :last]
1447             yield batch
1448
1449     def vocabulary_size(self):
1450         return self.nb_codes
1451
1452     def seq2str(self, s):
1453         return "".join([self.id2char[k.item()] for k in s])
1454
1455     def produce_results(
1456         self,
1457         n_epoch,
1458         model,
1459         result_dir,
1460         logger,
1461         deterministic_synthesis,
1462         input_file=None,
1463     ):
1464         def compute_nb_correct(input):
1465             result = input.clone()
1466             s = (result == self.space).long()
1467             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1468             result = (1 - ar_mask) * result + ar_mask * self.filler
1469             masked_inplace_autoregression(
1470                 model,
1471                 self.batch_size,
1472                 result,
1473                 ar_mask,
1474                 deterministic_synthesis,
1475                 device=self.device,
1476             )
1477
1478             nb_total = input.size(0)
1479             nb_correct = (input == result).long().min(1).values.sum()
1480
1481             #######################################################################
1482             # Compute predicted vs. true variable values
1483
1484             nb_delta = torch.zeros(5, dtype=torch.int64)
1485             nb_missed = 0
1486
1487             values_input = expr.extract_results([self.seq2str(s) for s in input])
1488             values_result = expr.extract_results([self.seq2str(s) for s in result])
1489
1490             filename = os.path.join(result_dir, f"expr_result_{n_epoch:04d}.txt")
1491
1492             with open(filename, "w") as f:
1493                 for i, r in zip(values_input, values_result):
1494                     for n, vi in i.items():
1495                         vr = r.get(n)
1496                         f.write(f"{vi} {-1 if vr is None else vr}\n")
1497
1498                         if vr is None or vr < 0:
1499                             nb_missed += 1
1500                         else:
1501                             d = abs(vr - vi)
1502                             if d >= nb_delta.size(0):
1503                                 nb_missed += 1
1504                             else:
1505                                 nb_delta[d] += 1
1506
1507             ######################################################################
1508
1509             return nb_total, nb_correct, nb_delta, nb_missed
1510
1511         (
1512             test_nb_total,
1513             test_nb_correct,
1514             test_nb_delta,
1515             test_nb_missed,
1516         ) = compute_nb_correct(self.test_input[:10000])
1517
1518         logger(
1519             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1520         )
1521
1522         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
1523
1524         nb_total = test_nb_delta.sum() + test_nb_missed
1525         for d in range(test_nb_delta.size(0)):
1526             logger(
1527                 f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%"
1528             )
1529         logger(
1530             f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%"
1531         )
1532
1533         ##############################################################
1534         # Log a few generated sequences
1535         if input_file is None:
1536             input = self.test_input[:10]
1537         else:
1538             with open(input_file, "r") as f:
1539                 sequences = [e.strip() for e in f.readlines()]
1540                 sequences = [s + " " + "#" * 50 for s in sequences]
1541                 input = self.tensorize(sequences)
1542
1543         result = input.clone()
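             # Same masking as in compute_nb_correct above: blank everything
             # after the first space of each sequence and let the model
             # regenerate it.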
1544         s = (result == self.space).long()
1545         ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1546         result = (1 - ar_mask) * result + ar_mask * self.filler
1547
1548         for n in range(result.size(0)):
1549             logger(f"test_before {self.seq2str(result[n])}")
1550
1551         masked_inplace_autoregression(
1552             model,
1553             self.batch_size,
1554             result,
1555             ar_mask,
1556             deterministic_synthesis,
1557             device=self.device,
1558         )
1559
1560         correct = (1 - ar_mask) * self.space + ar_mask * input
1561         for n in range(result.size(0)):
1562             comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
1563             logger(f"test_after  {self.seq2str(result[n])} {comment}")
1564             logger(f"truth       {self.seq2str(correct[n])}")
1565         ##############################################################
1566
1567
1568 ######################################################################
1569
1570 import grid
1571
1572
1573 class Grid(Task):
1574     # Make a tensor from a list of strings
1575     def str2tensor(self, descr):
1576         token_descr = [s.strip().split(" ") for s in descr]
1577         l = max([len(s) for s in token_descr])
1578         token_descr = [s + ["#"] * (l - len(s)) for s in token_descr]
1579         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
1580         return torch.tensor(id_descr, device=self.device)
1581
1582     # Make a list of strings from a tensor
1583     def tensor2str(self, x):
1584         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
1585
1586     # Trim all the tensors in z to remove as many filler tokens as
1587     # possible on the left and right of the first tensor. If z is a
1588     # tuple, all its elements are trimmed according to that first tensor.
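         # Thanks to the one-column padding on each side, i counts how many
         # columns containing a non-filler token have been seen so far, and
         # z[:, a:b] keeps exactly the span from the first to the last column
         # that contains at least one non-filler token.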
1589     def trim(self, z, token="#"):
1590         n = self.token2id[token]
1591         if type(z) == tuple:
1592             x = z[0]
1593             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
1594             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
1595             return tuple([t[:, a:b] for t in z])
1596         else:
1597             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
1598             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
1599             return z[:, a:b]
1600
1601     ######################
1602
1603     def __init__(
1604         self,
1605         nb_train_samples,
1606         nb_test_samples,
1607         batch_size,
1608         size,
1609         fraction_play=0.0,
1610         logger=None,
1611         device=torch.device("cpu"),
1612     ):
1613         super().__init__()
1614
1615         self.device = device
1616         self.batch_size = batch_size
1617         self.grid_factory = grid.GridFactory(size=size)
1618         self.fraction_play = fraction_play
1619
1620         if logger is not None:
1621             logger(
1622                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
1623             )
1624
1625         self.train_descr = self.grid_factory.generate_samples(
1626             nb=nb_train_samples,
1627             fraction_play=fraction_play,
1628             progress_bar=lambda r: tqdm.tqdm(r),
1629         )
1630
1631         self.test_descr = self.grid_factory.generate_samples(
1632             nb=nb_test_samples, fraction_play=0.0, progress_bar=lambda r: tqdm.tqdm(r)
1633         )
1634
1635         if fraction_play > 0:
1636             self.play_descr = self.grid_factory.generate_samples(
1637                 nb=25, fraction_play=1.0, progress_bar=lambda r: tqdm.tqdm(r)
1638             )
1639         else:
1640             self.play_descr = []
1641
1642         # Build the tokenizer
1643         tokens = set()
1644         for d in [self.train_descr, self.test_descr, self.play_descr]:
1645             for s in d:
1646                 for t in s.strip().split(" "):
1647                     tokens.add(t)
1648         # turn the set into a sorted list so that identical descriptions
1649         # always yield the same token ids, and hence the same tensors
1650         tokens = list(tokens)
1651         tokens.sort()
1652         tokens = ["#"] + tokens
1653         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
1654         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
1655         self.t_nul = self.token2id["#"]
1656         self.t_true = self.token2id["true"]
1657         self.t_false = self.token2id["false"]
1658         self.t_pipe = self.token2id["|"]
1659
1660         # Tokenize the train and test sets
1661         self.train_input = self.str2tensor(self.train_descr)
1662         self.test_input = self.str2tensor(self.test_descr)
1663         self.play_input = (
1664             None if len(self.play_descr) == 0 else self.str2tensor(self.play_descr)
1665         )
1666
1667     def batches(self, split="train"):
1668         assert split in {"train", "test"}
1669         input = self.train_input if split == "train" else self.test_input
1670         for batch in tqdm.tqdm(
1671             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
1672         ):
1673             yield self.trim(batch)
1674
1675     def vocabulary_size(self):
1676         return len(self.token2id)
1677
1678     def produce_results(
1679         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1680     ):
1681         correct = self.test_input[:1000]
1682         result = correct.clone()
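             # Only the positions holding the "true"/"false" answer tokens are
             # regenerated; the rest of each description is kept as context.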
1683         ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long()
1684         result *= 1 - ar_mask  # paranoia: blank the tokens that will be regenerated
1685
1686         logger(f"----------------------------------------------------------")
1687
1688         for e in self.tensor2str(result[:10]):
1689             logger(f"test_before {e}")
1690
1691         masked_inplace_autoregression(
1692             model,
1693             self.batch_size,
1694             result,
1695             ar_mask,
1696             deterministic_synthesis,
1697             device=self.device,
1698         )
1699
1700         logger(f"----------------------------------------------------------")
1701
1702         for e in self.tensor2str(result[:10]):
1703             logger(f"test_after  {e}")
1704
1705         logger(f"----------------------------------------------------------")
1706
1707         nb_total = ar_mask.sum().item()
1708         nb_correct = ((correct == result).long() * ar_mask).sum().item()
1709
1710         logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
1711         logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
1712
1713         if self.play_input is not None:
1714             result = self.play_input.clone()
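                  # Regenerate everything from the first "|" separator onward
                  # (the separator token itself included).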
1715             ar_mask = (result == self.t_pipe).long().cumsum(dim=1).clamp(max=1)
1716             result *= 1 - ar_mask  # paranoia: blank the tokens that will be regenerated
1717
1718             logger(f"----------------------------------------------------------")
1719
1720             for e in self.tensor2str(result[:10]):
1721                 logger(f"play_before {e}")
1722
1723             masked_inplace_autoregression(
1724                 model,
1725                 self.batch_size,
1726                 result,
1727                 ar_mask,
1728                 deterministic_synthesis,
1729                 device=self.device,
1730             )
1731
1732             logger(f"----------------------------------------------------------")
1733
1734             for e in self.tensor2str(result[:10]):
1735                 logger(f"play_after  {e}")
1736
1737             logger(f"----------------------------------------------------------")
1738
1739
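     # A minimal usage sketch of the Grid task (hypothetical argument values,
     # not executed here):
     #
     #   task = Grid(nb_train_samples=10000, nb_test_samples=1000,
     #               batch_size=25, size=6, device=torch.device("cuda"))
     #   for input in task.batches(split="train"):
     #       pass  # feed the batch to the model's training loop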
1740 ######################################################################
1741
1742 import qmlp
1743
1744
1745 class QMLP(Task):
1746     ######################
1747
1748     def __init__(
1749         self,
1750         nb_train_samples,
1751         nb_test_samples,
1752         batch_size,
1753         result_dir,
1754         logger=None,
1755         device=torch.device("cpu"),
1756     ):
1757         super().__init__()
1758
1759         self.device = device
1760         self.batch_size = batch_size
1761         self.nb_samples_per_mlp = 256
1762
1763         if logger is not None:
1764             logger(
1765                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
1766             )
1767
1768         seq, q_test_set, test_error = qmlp.generate_sequence_and_test_set(
1769             nb_mlps=nb_train_samples + nb_test_samples,
1770             nb_samples=self.nb_samples_per_mlp,
1771             device=self.device,
1772             batch_size=64,
1773             nb_epochs=250,
1774             nb_mlps_per_batch=1024,
1775         )
1776
1777         self.train_input = seq[:nb_train_samples]
1778         self.train_q_test_set = q_test_set[:nb_train_samples]
1779         self.train_ref_test_errors = test_error[:nb_train_samples]
1780         self.test_input = seq[nb_train_samples:]
1781         self.test_q_test_set = q_test_set[nb_train_samples:]
1782         self.test_ref_test_errors = test_error[nb_train_samples:]
1783
1784         filename = os.path.join(result_dir, f"train_errors_ref.dat")
1785         with open(filename, "w") as f:
1786             for e in self.train_ref_test_errors:
1787                 f.write(f"{e}\n")
1788
1789         filename = os.path.join(result_dir, f"test_errors_ref.dat")
1790         with open(filename, "w") as f:
1791             for e in self.test_ref_test_errors:
1792                 f.write(f"{e}\n")
1793
1794         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1795
1796     def batches(self, split="train"):
1797         assert split in {"train", "test"}
1798         input = self.train_input if split == "train" else self.test_input
1799         for batch in tqdm.tqdm(
1800             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
1801         ):
1802             yield batch
1803
1804     def vocabulary_size(self):
1805         return self.nb_codes
1806
1807     def produce_results(
1808         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1809     ):
1810         correct = self.test_input[:1000]
1811         result = correct.clone()
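             # Regenerate only the positions after index
             # 3 * nb_samples_per_mlp + 1, i.e. the tail of each sequence,
             # which holds the quantized MLP parameters; the per-MLP training
             # samples at the beginning are kept as context.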
1812         ar_mask = (
1813             torch.arange(result.size(1), device=result.device)
1814             > self.nb_samples_per_mlp * 3 + 1
1815         ).long()[None, :]
1816         ar_mask = ar_mask.expand_as(result)
1817         result *= 1 - ar_mask  # paranoia: blank the tokens that will be regenerated
1818
1819         masked_inplace_autoregression(
1820             model,
1821             self.batch_size,
1822             result,
1823             ar_mask,
1824             deterministic_synthesis,
1825             device=self.device,
1826         )
1827
1828         q_train_set = result[:, : self.nb_samples_per_mlp * 3]
1829         q_params = result[:, self.nb_samples_per_mlp * 3 + 1 :]
1830         error_test = qmlp.evaluate_q_params(q_params, self.test_q_test_set)
1831
1832         filename = os.path.join(result_dir, f"test_errors_{n_epoch:04d}.dat")
1833         with open(filename, "w") as f:
1834             for e in error_test:
1835                 f.write(f"{e}\n")
1836
1837
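     # A minimal usage sketch of the QMLP task (hypothetical argument values,
     # not executed here):
     #
     #   task = QMLP(nb_train_samples=2000, nb_test_samples=200,
     #               batch_size=25, result_dir="results",
     #               device=torch.device("cuda"))
     #   print(task.vocabulary_size())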
1838 ######################################################################