1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, os, tqdm
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 from mygpt import BracketedSequence
16
17 try:
18     from graph import save_attention_image
19 except ImportError:
20     save_attention_image = None
21
22 ######################################################################
23
24
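# Fill in, batch by batch and in place, the positions of `input` where
# `ar_mask` is 1 by sampling the model autoregressively; positions where
# `ar_mask` is 0 are kept as given. The model is switched to eval mode for
# the generation and restored to its previous training mode afterward.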
25 def masked_inplace_autoregression(
26     model,
27     batch_size,
28     input,
29     ar_mask,
30     deterministic_synthesis,
31     forbidden_tokens=None,
32     progress_bar_desc="autoregression",
33     device=torch.device("cpu"),
34 ):
35     assert input.size() == ar_mask.size()
36
37     batches = zip(input.split(batch_size), ar_mask.split(batch_size))
38
39     if progress_bar_desc is not None:
40         batches = tqdm.tqdm(
41             batches,
42             dynamic_ncols=True,
43             desc=progress_bar_desc,
44             total=(input.size(0) + batch_size - 1) // batch_size,
45         )
46
47     with torch.autograd.no_grad():
48         t = model.training
49         model.eval()
50
51         for input, ar_mask in batches:
52             model.masked_inplace_autoregression(
53                 input, ar_mask, forbidden_tokens, deterministic_synthesis
54             )
55
56         model.train(t)
57
58
59 ######################################################################
60
61
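# Base class for all tasks: a task provides batches of token sequences,
# the size of its vocabulary, and a procedure that evaluates a model and
# logs / writes its results.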
62 class Task:
63     def batches(self, split="train"):
64         pass
65
66     def vocabulary_size(self):
67         pass
68
69     def produce_results(
70         self, n_epoch, model, result_dir, logger, deterministic_synthesis
71     ):
72         pass
73
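# A rough usage sketch for a Task subclass (illustrative only: the driver,
# mygpt.MyGPT, and the loss shown here are assumptions, not defined in this
# file):
#
#   task = SandBox(problem, nb_train_samples, nb_test_samples, batch_size)
#   model = mygpt.MyGPT(vocabulary_size=task.vocabulary_size(), ...)
#   for input in task.batches(split="train"):
#       output = model(BracketedSequence(input)).x
#       loss = F.cross_entropy(output.transpose(1, 2), input)
#       ...  # backward pass and optimizer step
#   task.produce_results(n_epoch, model, result_dir, logger, deterministic_synthesis)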
74
75 ####################
76
77 import problems
78
79
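# Generic task wrapping a problems.Problem instance: the problem object
# generates the sequences and their autoregression masks, turns sequences
# into strings for logging, and counts correct predictions.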
80 class SandBox(Task):
81     def __init__(
82         self,
83         problem,
84         nb_train_samples,
85         nb_test_samples,
86         batch_size,
87         logger=None,
88         device=torch.device("cpu"),
89         max_nb_codes=1024,
90     ):
91         super().__init__()
92
93         self.batch_size = batch_size
94         self.device = device
95         self.problem = problem
96
97         self.train_input, self.train_ar_mask = self.problem.generate_sequences(
98             nb_train_samples
99         )
100         self.test_input, self.test_ar_mask = self.problem.generate_sequences(
101             nb_test_samples
102         )
103
104         self.train_input, self.train_ar_mask = self.train_input.to(
105             device
106         ), self.train_ar_mask.to(device)
107         self.test_input, self.test_ar_mask = self.test_input.to(
108             device
109         ), self.test_ar_mask.to(device)
110
111         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
112
113
114         # A bit of paranoia never hurts
115         assert (
116             self.nb_codes <= max_nb_codes
117             and self.train_input.min() >= 0
118             and self.test_input.min() >= 0
119             and tuple(x.item() for x in self.train_ar_mask.unique()) in { (0,), (1,), (0,1) }
120             and tuple(x.item() for x in self.test_ar_mask.unique()) in { (0,), (1,), (0,1) }
121         )
122
123     def batches(self, split="train", nb_to_use=-1, desc=None):
124         assert split in {"train", "test"}
125         input = self.train_input if split == "train" else self.test_input
126         if nb_to_use > 0:
127             input = input[:nb_to_use]
128         if desc is None:
129             desc = f"epoch-{split}"
130         for batch in tqdm.tqdm(
131             input.split(self.batch_size), dynamic_ncols=True, desc=desc
132         ):
133             yield batch
134
135     def vocabulary_size(self):
136         return self.nb_codes
137
138     def produce_results(
139         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
140     ):
141         def compute_accuracy(input, ar_mask, logger=None):
142             input, ar_mask = input[:nmax], ar_mask[:nmax]
143             result = input.clone() * (1 - ar_mask)
144
145             masked_inplace_autoregression(
146                 model,
147                 self.batch_size,
148                 result,
149                 ar_mask,
150                 deterministic_synthesis,
151                 progress_bar_desc=None,
152                 device=self.device,
153             )
154
155             if logger is not None:
156                 for sp, st in zip(result[:10], input[:10]):
157                     logger(
158                         f"test_sequences {n_epoch} prediction   {self.problem.seq2str(sp)}"
159                     )
160                     logger(
161                         f"               {n_epoch} ground truth {self.problem.seq2str(st)}"
162                     )
163
164             nb_total, nb_correct = self.problem.compute_nb_correct(input, ar_mask, result)
165
166             # nb_total = ar_mask.sum().item()
167             # nb_correct = ((result == input).long() * ar_mask).sum().item()
168
169             return nb_total, nb_correct
170
171         train_nb_total, train_nb_correct = compute_accuracy(
172             self.train_input, self.train_ar_mask
173         )
174
175         logger(
176             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
177         )
178
179         test_nb_total, test_nb_correct = compute_accuracy(
180             self.test_input, self.test_ar_mask, logger
181         )
182
183         logger(
184             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
185         )
186
187         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
188
189         if save_attention_image is None:
190             logger("no save_attention_image (is pycairo installed?)")
191         else:
192             for k in range(10):
193                 ns = torch.randint(self.test_input.size(0), (1,)).item()
194                 input = self.test_input[ns : ns + 1].clone()
195
196                 with torch.autograd.no_grad():
197                     t = model.training
198                     model.eval()
199                     model.record_attention(True)
200                     model(BracketedSequence(input))
201                     model.train(t)
202                     ram = model.retrieve_attention()
203                     model.record_attention(False)
204
205                 tokens_output = [c for c in self.problem.seq2str(input[0])]
206                 tokens_input = ["n/a"] + tokens_output[:-1]
207                 for n_head in range(ram[0].size(1)):
208                     filename = os.path.join(
209                         result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
210                     )
211                     attention_matrices = [m[0, n_head] for m in ram]
212                     save_attention_image(
213                         filename,
214                         tokens_input,
215                         tokens_output,
216                         attention_matrices,
217                         k_top=10,
218                         # min_total_attention=0.9,
219                         token_gap=12,
220                         layer_gap=50,
221                     )
222                     logger(f"wrote {filename}")
223
224
225 ######################################################################
226
227 import picoclvr
228
229
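# Each sample is a textual scene description followed by an <img> token and
# the grid of image tokens. At evaluation time everything from <img> onward
# is regenerated, and picoclvr.nb_properties counts how many of the requested
# properties are missing from the generated scene.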
230 class PicoCLVR(Task):
231     # Make a tensor from a list of strings
232     def tensorize(self, descr):
233         token_descr = [s.strip().split(" ") for s in descr]
234         l = max([len(s) for s in token_descr])
235         token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
236         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
237         return torch.tensor(id_descr, device=self.device)
238
239     # Make a list of strings from a tensor
240     def detensorize(self, x):
241         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
242
243     # Trim the tensor z to remove as many padding tokens as possible on
244     # the left and right. If z is a tuple, all its elements are trimmed
245     # according to the trimming computed for the first one.
246     def trim(self, z, token="<nul>"):
247         n = self.token2id[token]
248         if type(z) == tuple:
249             x = z[0]
250             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
251             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
252             return tuple([t[:, a:b] for t in z])
253         else:
254             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
255             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
256             return z[:, a:b]
257
258     ######################
259
260     def __init__(
261         self,
262         nb_train_samples,
263         nb_test_samples,
264         batch_size,
265         height,
266         width,
267         nb_colors=5,
268         logger=None,
269         device=torch.device("cpu"),
270         pruner_train=None,
271         pruner_eval=None,
272     ):
273         super().__init__()
274
275         def generate_descr(nb, cache_suffix, pruner):
276             return picoclvr.generate(
277                 nb,
278                 height=self.height,
279                 width=self.width,
280                 nb_colors=nb_colors,
281                 pruner=pruner,
282             )
283
284         self.height = height
285         self.width = width
286         self.batch_size = batch_size
287         self.device = device
288         self.pruner_train = pruner_train
289         self.pruner_eval = pruner_eval
290
291         if logger is not None:
292             logger(
293                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
294             )
295
296         self.train_descr = generate_descr(
297             nb_train_samples, "train", pruner=self.pruner_train
298         )
299         self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)
300
301         # Build the tokenizer
302         tokens = {"<nul>", "<img>"}
303         for d in [self.train_descr, self.test_descr]:
304             for s in d:
305                 for t in s.strip().split(" "):
306                     tokens.add(t)
307         # make this set a sorted list to get the same tensors given
308         # the same descr
309         tokens = list(tokens)
310         tokens.sort()
311         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
312         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
313         self.t_img, self.t_nul = self.token2id["<img>"], self.token2id["<nul>"]
314
315         # Tokenize the train and test sets
316         self.train_input = self.tensorize(self.train_descr)
317         self.test_input = self.tensorize(self.test_descr)
318
319     def batches(self, split="train"):
320         assert split in {"train", "test"}
321         input = self.train_input if split == "train" else self.test_input
322         for batch in tqdm.tqdm(
323             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
324         ):
325             yield self.trim(batch)
326
327     def vocabulary_size(self):
328         return len(self.token2id)
329
330     def compute_missing_properties(
331         self, n_epoch, model, logger, deterministic_synthesis, pruner=None
332     ):
333         acc_nb_requested_properties = []
334         acc_nb_missing_properties = []
335         acc_nb_results = 0
336
337         for input in tqdm.tqdm(
338             self.test_input.split(self.batch_size),
339             dynamic_ncols=True,
340             desc=f"test-properties",
341         ):
342             result = input.clone()
343             ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1)
344             result = (1 - ar_mask) * result + ar_mask * self.t_nul
345             masked_inplace_autoregression(
346                 model,
347                 self.batch_size,
348                 result,
349                 ar_mask,
350                 deterministic_synthesis,
351                 progress_bar_desc=None,
352                 device=self.device,
353             )
354
355             result_descr = self.detensorize(result)
356             np = picoclvr.nb_properties(
357                 result_descr,
358                 height=self.height,
359                 width=self.width,
360                 pruner=pruner,
361             )
362             nb_requested_properties, _, nb_missing_properties = zip(*np)
363             acc_nb_requested_properties += nb_requested_properties
364             acc_nb_missing_properties += nb_missing_properties
365             acc_nb_results += len(result_descr)
366
367         nb_requested_properties = sum(acc_nb_requested_properties)
368         nb_missing_properties = sum(acc_nb_missing_properties)
369
370         prefix = "" if pruner is None else "pruned_"
371         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
372         logger(
373             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
374         )
375         logger(
376             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
377         )
378
379         logger(
380             f"main_test_accuracy {n_epoch} {1-nb_missing_properties/nb_requested_properties}"
381         )
382
383     ######################################################################
384
385     def produce_results(
386         self, n_epoch, model, result_dir, logger, deterministic_synthesis
387     ):
388         self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis)
389
390         if self.pruner_eval is not None:
391             self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis, pruner=self.pruner_eval)
392
393         nb_tokens_to_generate = self.height * self.width + 3
394         result_descr = []
395         nb_per_primer = 8
396         primer = []
397
398         for primer_descr in [
399             "red above green <sep> green top <sep> blue right of red",
400             "there is red <sep> there is yellow <sep> there is blue",
401             "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
402             "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
403         ]:
404             primer += [primer_descr + " <img>"] * nb_per_primer
405
406         result = self.tensorize(primer)
407         fill = result.new_full(
408             result.size()[:-1] + (self.height * self.width + 1,), self.t_nul
409         )
410         result = torch.cat((result, fill), 1)
411         ar_mask = (result == self.t_nul).long()
412         masked_inplace_autoregression(
413             model,
414             self.batch_size,
415             result,
416             ar_mask,
417             deterministic_synthesis,
418             device=self.device,
419         )
420         result_descr = self.detensorize(result)
421
422         np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
423
424         acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
425         acc_nb_results = len(result_descr)
426
427         nb_requested_properties = sum(acc_nb_requested_properties)
428         nb_missing_properties = sum(acc_nb_missing_properties)
429
430         prefix = "demo_"
431         logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
432         logger(
433             f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
434         )
435         logger(
436             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
437         )
438
439         img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)
440
441         if img.dim() == 5:
442             if img.size(1) == 1:
443                 img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
444             else:
445                 img = torch.cat(
446                     [
447                         torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
448                         for x in img
449                     ],
450                     0,
451                 )
452
453         image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png")
454         torchvision.utils.save_image(
455             img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
456         )
457         logger(f"wrote {image_name}")
458
459
460 ######################################################################
461
462
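# Each sequence is the 28x28=784 pixel intensities (0-255) of an MNIST digit.
# produce_results generates 64 images from scratch and saves them as a PNG.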
463 class MNIST(Task):
464     def __init__(
465         self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
466     ):
467         super().__init__()
468
469         self.nb_train_samples = nb_train_samples
470         self.nb_test_samples = nb_test_samples
471         self.batch_size = batch_size
472         self.device = device
473         data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True)
474         self.train_input = data_set.data[:nb_train_samples].view(-1, 28 * 28).long()
475         data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
476         self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
477
478     def batches(self, split="train", nb_to_use=-1, desc=None):
479         assert split in {"train", "test"}
480         input = self.train_input if split == "train" else self.test_input
481         if nb_to_use > 0:
482             input = input[:nb_to_use]
483         if desc is None:
484             desc = f"epoch-{split}"
485         for batch in tqdm.tqdm(
486             input.split(self.batch_size), dynamic_ncols=True, desc=desc
487         ):
488             yield batch
489
490     def vocabulary_size(self):
491         return 256
492
493     def produce_results(
494         self, n_epoch, model, result_dir, logger, deterministic_synthesis
495     ):
496         results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
497         ar_mask = torch.full_like(results, 1)
498         masked_inplace_autoregression(
499             model,
500             self.batch_size,
501             results,
502             ar_mask,
503             deterministic_synthesis,
504             device=self.device,
505         )
506         image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
507         torchvision.utils.save_image(
508             1 - results.reshape(-1, 1, 28, 28) / 255.0,
509             image_name,
510             nrow=16,
511             pad_value=0.8,
512         )
513         logger(f"wrote {image_name}")
514
515
516 ######################################################################
517
518 import maze
519
520
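# Each sequence is a flattened maze map concatenated with the flattened
# optimal path. Evaluation blanks out and regenerates the path half and
# checks it with maze.path_correctness / maze.path_optimality.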
521 class Maze(Task):
522     def map2seq(self, *m):
523         return torch.cat([x.flatten(1) for x in m], 1)
524
525     def seq2map(self, s):
526         s = s.reshape(s.size(0), -1, self.height, self.width)
527         return (s[:, k] for k in range(s.size(1)))
528
529     def __init__(
530         self,
531         nb_train_samples,
532         nb_test_samples,
533         batch_size,
534         height,
535         width,
536         nb_walls,
537         device=torch.device("cpu"),
538     ):
539         super().__init__()
540
541         self.batch_size = batch_size
542         self.height = height
543         self.width = width
544         self.device = device
545
546         train_mazes, train_paths, _ = maze.create_maze_data(
547             nb_train_samples,
548             height=height,
549             width=width,
550             nb_walls=nb_walls,
551             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
552         )
553         self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
554
555         test_mazes, test_paths, _ = maze.create_maze_data(
556             nb_test_samples,
557             height=height,
558             width=width,
559             nb_walls=nb_walls,
560             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
561         )
562         self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
563
564         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
565
566     def batches(self, split="train", nb_to_use=-1, desc=None):
567         assert split in {"train", "test"}
568         input = self.train_input if split == "train" else self.test_input
569         if nb_to_use > 0:
570             input = input[:nb_to_use]
571         if desc is None:
572             desc = f"epoch-{split}"
573         for batch in tqdm.tqdm(
574             input.split(self.batch_size), dynamic_ncols=True, desc=desc
575         ):
576             yield batch
577
578     def vocabulary_size(self):
579         return self.nb_codes
580
581     def compute_error(
582         self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
583     ):
584         nb_total, nb_correct = 0, 0
585         count = torch.zeros(
586             self.width * self.height,
587             self.width * self.height,
588             device=self.device,
589             dtype=torch.int64,
590         )
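        # count[i, j] counts the correctly solved mazes whose optimal path
        # has length i and whose predicted path has length j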
591
592         for input in self.batches(split, nb_to_use):
593             result = input.clone()
594             ar_mask = result.new_zeros(result.size())
595             ar_mask[:, self.height * self.width :] = 1
596             result *= 1 - ar_mask
597             masked_inplace_autoregression(
598                 model,
599                 self.batch_size,
600                 result,
601                 ar_mask,
602                 deterministic_synthesis,
603                 progress_bar_desc=None,
604                 device=self.device,
605             )
606             mazes, paths = self.seq2map(result)
607             path_correctness = maze.path_correctness(mazes, paths)
608             nb_correct += path_correctness.long().sum()
609             nb_total += mazes.size(0)
610
611             optimal_path_lengths = (
612                 (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
613             )
614             predicted_path_lengths = (
615                 (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
616             )
617             optimal_path_lengths = optimal_path_lengths[path_correctness]
618             predicted_path_lengths = predicted_path_lengths[path_correctness]
619             count.index_put_((optimal_path_lengths, predicted_path_lengths), torch.ones_like(optimal_path_lengths), accumulate=True)
620
621         if count.max() == 0:
622             count = None
623         else:
624             count = count[
625                 : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
626             ]
627
628         return nb_total, nb_correct, count
629
630     def produce_results(
631         self, n_epoch, model, result_dir, logger, deterministic_synthesis
632     ):
633         train_nb_total, train_nb_correct, count = self.compute_error(
634             model,
635             "train",
636             nb_to_use=1000,
637             deterministic_synthesis=deterministic_synthesis,
638         )
639         logger(
640             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
641         )
642
643         test_nb_total, test_nb_correct, count = self.compute_error(
644             model,
645             "test",
646             nb_to_use=1000,
647             deterministic_synthesis=deterministic_synthesis,
648         )
649         logger(
650             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
651         )
652
653         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
654
655         if count is not None:
656             proportion_optimal = count.diagonal().sum().float() / count.sum()
657             logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
658             with open(
659                 os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
660             ) as f:
661                 for i in range(count.size(0)):
662                     for j in range(count.size(1)):
663                         eol = " " if j < count.size(1) - 1 else "\n"
664                         f.write(f"{count[i,j]}{eol}")
665
666         input = self.test_input[:48]
667         result = input.clone()
668         ar_mask = result.new_zeros(result.size())
669         ar_mask[:, self.height * self.width :] = 1
670         result *= 1 - ar_mask
671         masked_inplace_autoregression(
672             model,
673             self.batch_size,
674             result,
675             ar_mask,
676             deterministic_synthesis,
677             device=self.device,
678         )
679
680         mazes, paths = self.seq2map(input)
681         _, predicted_paths = self.seq2map(result)
682
683         filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png")
684         maze.save_image(
685             filename,
686             mazes=mazes,
687             target_paths=paths,
688             predicted_paths=predicted_paths,
689             path_correct=maze.path_correctness(mazes, predicted_paths),
690             path_optimal=maze.path_optimality(paths, predicted_paths),
691         )
692         logger(f"wrote {filename}")
693
694
695 ######################################################################
696
697
698 import snake
699
700
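# A snake moves over a grid of colored cells; the sequences produced by
# snake.generate_sequences presumably interleave moves and observed cell
# colors (the exact encoding lives in snake.py). Accuracy is measured only
# at cells the snake has already visited, where the value is determined by
# the preceding part of the sequence.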
701 class Snake(Task):
702     def __init__(
703         self,
704         nb_train_samples,
705         nb_test_samples,
706         batch_size,
707         height,
708         width,
709         nb_colors,
710         length,
711         prompt_length,
712         device=torch.device("cpu"),
713     ):
714         super().__init__()
715
716         self.batch_size = batch_size
717         self.height = height
718         self.width = width
719         self.device = device
720         self.prompt_length = prompt_length
721
722         self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
723             nb_train_samples,
724             height,
725             width,
726             nb_colors,
727             length,
728             prompt_length,
729             self.device,
730         )
731         self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
732             nb_test_samples,
733             height,
734             width,
735             nb_colors,
736             length,
737             prompt_length,
738             self.device,
739         )
740
741         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
742
743     def batches(self, split="train", nb_to_use=-1, desc=None):
744         assert split in {"train", "test"}
745         input = self.train_input if split == "train" else self.test_input
746         if nb_to_use > 0:
747             input = input[:nb_to_use]
748         if desc is None:
749             desc = f"epoch-{split}"
750         for batch in tqdm.tqdm(
751             input.split(self.batch_size), dynamic_ncols=True, desc=desc
752         ):
753             yield batch
754
755     def vocabulary_size(self):
756         return self.nb_codes
757
758     def produce_results(
759         self, n_epoch, model, result_dir, logger, deterministic_synthesis
760     ):
761         def compute_nb_correct(input, prior_visits):
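            # Past the prompt, blank out and regenerate every second token;
            # correctness is only assessed where prior_visits > 0, i.e. at
            # cells already seen earlier in the sequence.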
762             result = input.clone()
763             i = torch.arange(result.size(1), device=result.device)[None, :]
764             ar_mask = (
765                 torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
766                 .long()
767                 .expand_as(result)
768             )
769             result *= 1 - ar_mask
770
771             masked_inplace_autoregression(
772                 model,
773                 self.batch_size,
774                 result,
775                 ar_mask,
776                 deterministic_synthesis,
777                 device=self.device,
778             )
779
780             nb_total = ((prior_visits > 0) * ar_mask).sum()
781
782             nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum()
783
784             return nb_total, nb_correct
785
786         test_nb_total, test_nb_correct = compute_nb_correct(
787             self.test_input[:1000], self.test_prior_visits[:1000]
788         )
789
790         logger(
791             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
792         )
793
794         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
795
796
797 ######################################################################
798
799
800 import stack
801
802
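# Sequences of push/pop operations on nb_stacks stacks holding nb_digits-digit
# values (from stack.generate_sequences). The model is evaluated on its
# ability to predict the values returned by the pop operations.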
803 class Stack(Task):
804     def __init__(
805         self,
806         nb_train_samples,
807         nb_test_samples,
808         batch_size,
809         logger,
810         nb_steps,
811         nb_stacks,
812         nb_digits,
813         fraction_values_for_train=None,
814         device=torch.device("cpu"),
815     ):
816         super().__init__()
817
818         self.batch_size = batch_size
819         self.nb_steps = nb_steps
820         self.nb_stacks = nb_stacks
821         self.nb_digits = nb_digits
822         self.device = device
823
824         if fraction_values_for_train is None:
825             values_for_train = None
826             values_for_test = None
827         else:
828             all = torch.randperm(10**nb_digits)
829             nb_for_train = int(all.size(0) * fraction_values_for_train)
830             values_for_train = all[:nb_for_train]
831             values_for_test = all[nb_for_train:]
832
833         self.train_input, self.train_stack_counts = stack.generate_sequences(
834             nb_train_samples,
835             nb_steps,
836             nb_stacks,
837             nb_digits,
838             values_for_train,
839             self.device,
840         )
841
842         self.test_input, self.test_stack_counts = stack.generate_sequences(
843             nb_test_samples,
844             nb_steps,
845             nb_stacks,
846             nb_digits,
847             values_for_test,
848             self.device,
849         )
850
851         i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
852         counts = self.test_stack_counts.flatten()[i.flatten()]
853         counts = F.one_hot(counts).sum(0)
854         logger(f"test_pop_stack_counts {counts}")
855
856         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
857
858     def batches(self, split="train", nb_to_use=-1, desc=None):
859         assert split in {"train", "test"}
860         input = self.train_input if split == "train" else self.test_input
861         if nb_to_use > 0:
862             input = input[:nb_to_use]
863         if desc is None:
864             desc = f"epoch-{split}"
865         for batch in tqdm.tqdm(
866             input.split(self.batch_size), dynamic_ncols=True, desc=desc
867         ):
868             yield batch
869
870     def vocabulary_size(self):
871         return self.nb_codes
872
873     def produce_results(
874         self, n_epoch, model, result_dir, logger, deterministic_synthesis
875     ):
876         def compute_nb_correct(input):
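            # stack.remove_popped_values blanks the values produced by pops;
            # those positions (where result differs from input) are regenerated,
            # and a value counts as correct only if all of its digits are.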
877             result = input.clone()
878             stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
879             ar_mask = (result != input).long()
880             masked_inplace_autoregression(
881                 model,
882                 self.batch_size,
883                 result,
884                 ar_mask,
885                 deterministic_synthesis,
886                 device=self.device,
887             )
888
889             errors = ((result != input).long() * ar_mask).reshape(
890                 -1, 1 + self.nb_digits
891             )
892             ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)
893
894             nb_total = ar_mask.max(1).values.sum()
895             nb_correct = nb_total - errors.max(1).values.sum()
896
897             return nb_total, nb_correct
898
899         test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
900
901         logger(
902             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
903         )
904
905         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
906
907         ##############################################################
908         # Log a few generated sequences
909         input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
910         result = input.clone()
911         stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
912         ar_mask = (result != input).long()
913
914         # for n in range(result.size(0)):
915         # logger(
916         # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
917         # )
918
919         masked_inplace_autoregression(
920             model,
921             self.batch_size,
922             result,
923             ar_mask,
924             deterministic_synthesis,
925             device=self.device,
926         )
927
928         for n in range(result.size(0)):
929             logger(
930                 f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
931             )
932         ##############################################################
933
934
935 ######################################################################
936
937 import rpl
938
939
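# Each sequence contains input/output stack pairs (<in> ... <out> ...)
# followed by the program (<prg> ... <end>) that produced them. With no_prog
# the program part is excised. Evaluation checks program synthesis (when the
# program is present) and prediction of the final output.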
940 class RPL(Task):
941     def tensorize(self, sequences):
942         len_max = max([len(x) for x in sequences])
943         return torch.cat(
944             [
945                 torch.tensor(
946                     [
947                         [
948                             self.token2id[str(c)]
949                             for c in s + ["<nul>"] * (len_max - len(s))
950                         ]
951                         for s in sequences
952                     ]
953                 )
954             ],
955             0,
956         )
957
958     def seq2str(self, seq):
959         return " ".join([self.id2token[i] for i in seq])
960
961     def __init__(
962         self,
963         nb_train_samples,
964         nb_test_samples,
965         batch_size,
966         nb_starting_values=3,
967         max_input=9,
968         prog_len=6,
969         nb_runs=5,
970         no_prog=False,
971         logger=None,
972         device=torch.device("cpu"),
973     ):
974         super().__init__()
975
976         self.batch_size = batch_size
977         self.device = device
978         self.no_prog = no_prog
979
980         train_sequences = [
981             rpl.generate(
982                 nb_starting_values=nb_starting_values,
983                 nb_result_values_max=4 * nb_starting_values,
984                 max_input=max_input,
985                 prog_len=prog_len,
986                 nb_runs=nb_runs,
987             )
988             for _ in tqdm.tqdm(range(nb_train_samples), desc="train-data")
989         ]
990
991         test_sequences = [
992             rpl.generate(
993                 nb_starting_values=nb_starting_values,
994                 nb_result_values_max=4 * nb_starting_values,
995                 max_input=max_input,
996                 prog_len=prog_len,
997                 nb_runs=nb_runs,
998             )
999             for _ in tqdm.tqdm(range(nb_test_samples), desc="test-data")
1000         ]
1001
1002         symbols = list(
1003             set(["<nul>"] + [x for l in train_sequences + test_sequences for x in l])
1004         )
1005         val_max = max([x if type(x) is int else 0 for x in symbols])
1006         symbols = list(filter(lambda x: type(x) is str, symbols))
1007         symbols.sort()
1008         symbols += [str(n) for n in range(val_max + 1)]
1009         self.token2id = dict([(c, n) for n, c in enumerate(symbols)])
1010         self.id2token = dict([(n, c) for c, n in self.token2id.items()])
1011
1012         self.t_nul = self.token2id["<nul>"]
1013         self.t_input = self.token2id["<in>"]
1014         self.t_output = self.token2id["<out>"]
1015         self.t_prog = self.token2id["<prg>"]
1016         self.t_end = self.token2id["<end>"]
1017
1018         self.train_input = self.tensorize(train_sequences)
1019         self.test_input = self.tensorize(test_sequences)
1020
1021         if no_prog:
1022             # Excise the program from every train and test example
1023             k = torch.arange(self.train_input.size(1), device=self.train_input.device)[
1024                 None, :
1025             ]
1026             p = (
1027                 ((self.train_input == self.t_prog).long() * k)
1028                 .max(1, keepdim=True)
1029                 .values
1030             )
1031             self.train_input = (
1032                 self.train_input * (k <= p).long()
1033                 + self.t_end * (k == p + 1).long()
1034                 + self.t_nul * (k > p + 1).long()
1035             )
1036             k = torch.arange(self.test_input.size(1), device=self.test_input.device)[
1037                 None, :
1038             ]
1039             p = (
1040                 ((self.test_input == self.t_prog).long() * k)
1041                 .max(1, keepdim=True)
1042                 .values
1043             )
1044             self.test_input = (
1045                 self.test_input * (k <= p).long()
1046                 + self.t_end * (k == p + 1).long()
1047                 + self.t_nul * (k > p + 1).long()
1048             )
1049
1050         if logger is not None:
1051             logger(f"value_max {val_max}")
1052             for x in self.train_input[:25]:
1053                 end = (x != self.t_nul).nonzero().max().item() + 1
1054                 seq = [self.id2token[i.item()] for i in x[:end]]
1055                 s = " ".join(seq)
1056                 logger(f"example_seq {s}")
1057
1058         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1059
1060     def batches(self, split="train", nb_to_use=-1, desc=None):
1061         assert split in {"train", "test"}
1062         input = self.train_input if split == "train" else self.test_input
1063         if nb_to_use > 0:
1064             input = input[:nb_to_use]
1065         if desc is None:
1066             desc = f"epoch-{split}"
1067         for batch in tqdm.tqdm(
1068             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1069         ):
1070             last = (batch != self.t_nul).max(0).values.nonzero().max() + 3
1071             batch = batch[:, :last].to(self.device)
1072             yield batch
1073
1074     def vocabulary_size(self):
1075         return self.nb_codes
1076
1077     def produce_results(
1078         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1079     ):
1080         # --------------------------------------------------------------------
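        # Blank out and regenerate everything after the first <prg> token,
        # i.e. the program, then check the predicted program against the
        # given input/output examples with rpl.compute_nb_errors.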
1081         def compute_nb_errors_prog(input, nb_to_log=0):
1082             result = input.clone()
1083             s = (result == self.t_prog).long()
1084             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1085             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1086
1087             masked_inplace_autoregression(
1088                 model,
1089                 self.batch_size,
1090                 result,
1091                 ar_mask,
1092                 deterministic_synthesis,
1093                 device=self.device,
1094             )
1095
1096             sum_nb_total, sum_nb_errors = 0, 0
1097             for one_input, one_result in zip(input, result):
1098                 seq = [self.id2token[i.item()] for i in one_result]
1099                 nb_total, nb_errors, prog, stacks = rpl.compute_nb_errors(seq)
1100                 sum_nb_total += 1
1101                 sum_nb_errors += 0 if nb_errors == 0 else 1
1102                 if nb_to_log > 0:
1103                     gt_seq = [self.id2token[i.item()] for i in one_input]
1104                     _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq)
1105                     gt_prog = " ".join([str(x) for x in gt_prog])
1106                     prog = " ".join([str(x) for x in prog])
1107                     comment = "*" if nb_errors == 0 else "-"
1108                     logger(f"{comment} PROG [{gt_prog}] PREDICTED [{prog}]")
1109                     for start_stack, target_stack, result_stack, correct in stacks:
1110                         comment = "*" if correct else "-"
1111                         start_stack = " ".join([str(x) for x in start_stack])
1112                         target_stack = " ".join([str(x) for x in target_stack])
1113                         result_stack = " ".join([str(x) for x in result_stack])
1114                         logger(
1115                             f"  {comment} [{start_stack}] -> [{target_stack}] PREDICTED [{result_stack}]"
1116                         )
1117                     nb_to_log -= 1
1118
1119             return sum_nb_total, sum_nb_errors
1120
1121         # --------------------------------------------------------------------
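        # Blank out and regenerate the tokens strictly between the last <out>
        # marker and the <prg> marker, i.e. the final output stack, and
        # compare it to the ground truth.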
1122         def compute_nb_errors_output(input, nb_to_log=0):
1123             result = input.clone()
1124             k = torch.arange(result.size(1), device=result.device)[None, :]
1125             last_output_idx = (
1126                 ((result == self.t_output) * k).max(dim=1, keepdim=True).values
1127             )
1128             first_prog_idx = (
1129                 ((result == self.t_prog) * k).max(dim=1, keepdim=True).values
1130             )
1131             ar_mask = (k > last_output_idx).long() * (k < first_prog_idx).long()
1132             result = (1 - ar_mask) * result + ar_mask * self.t_nul
1133
1134             masked_inplace_autoregression(
1135                 model,
1136                 self.batch_size,
1137                 result,
1138                 ar_mask,
1139                 deterministic_synthesis,
1140                 device=self.device,
1141             )
1142
1143             sum_nb_total, sum_nb_errors = 0, 0
1144             for one_input, one_result, i, j in zip(
1145                 input, result, last_output_idx, first_prog_idx
1146             ):
1147                 seq = [self.id2token[i.item()] for i in one_result]
1148                 sum_nb_total += 1
1149                 correct = (one_input - one_result).abs().max() == 0
1150                 sum_nb_errors += 0 if correct else 1
1151                 if nb_to_log > 0:
1152                     result_stack = [
1153                         self.id2token[i.item()] for i in one_result[i : j + 1]
1154                     ]
1155                     target_stack = [
1156                         self.id2token[i.item()] for i in one_input[i : j + 1]
1157                     ]
1158                     comment = "*" if correct else "-"
1159                     result_stack = " ".join([str(x) for x in result_stack])
1160                     target_stack = " ".join([str(x) for x in target_stack])
1161                     logger(
1162                         f"output_test {comment} [{target_stack}] PREDICTED [{result_stack}]"
1163                     )
1164                     nb_to_log -= 1
1165
1166             return sum_nb_total, sum_nb_errors
1167
1168         # --------------------------------------------------------------------
1169
1170         if not self.no_prog:
1171             test_nb_total, test_nb_errors = compute_nb_errors_prog(
1172                 self.test_input[:1000].to(self.device), nb_to_log=10
1173             )
1174
1175             logger(
1176                 f"accuracy_prog_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1177             )
1178
1179             logger(f"main_test_accuracy {n_epoch} {1-test_nb_errors/test_nb_total}")
1180
1181         test_nb_total, test_nb_errors = compute_nb_errors_output(
1182             self.test_input[:1000].to(self.device), nb_to_log=10
1183         )
1184
1185         logger(
1186             f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
1187         )
1188
1189         if save_attention_image is None:
1190             logger("no save_attention_image (is pycairo installed?)")
1191         else:
1192             ns = torch.randint(self.test_input.size(0), (1,)).item()
1193             input = self.test_input[ns : ns + 1].clone()
1194             last = (input != self.t_nul).max(0).values.nonzero().max() + 3
1195             input = input[:, :last].to(self.device)
1196
1197             with torch.autograd.no_grad():
1198                 t = model.training
1199                 model.eval()
1200                 model.record_attention(True)
1201                 model(BracketedSequence(input))
1202                 model.train(t)
1203                 ram = model.retrieve_attention()
1204                 model.record_attention(False)
1205
1206             tokens_output = [self.id2token[i.item()] for i in input[0]]
1207             tokens_input = ["n/a"] + tokens_output[:-1]
1208             for n_head in range(ram[0].size(1)):
1209                 filename = os.path.join(
1210                     result_dir, f"rpl_attention_{n_epoch}_h{n_head}.pdf"
1211                 )
1212                 attention_matrices = [m[0, n_head] for m in ram]
1213                 save_attention_image(
1214                     filename,
1215                     tokens_input,
1216                     tokens_output,
1217                     attention_matrices,
1218                     k_top=10,
1219                     # min_total_attention=0.9,
1220                     token_gap=12,
1221                     layer_gap=50,
1222                 )
1223                 logger(f"wrote {filename}")
1224
1225
1226 ######################################################################
1227
1228
1229 import expr
1230
1231
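# Character-level task: expr.generate_sequences produces strings of variable
# assignments followed by their computed values after the first space
# character; '#' is the padding character. Evaluation regenerates the part
# after the first space and compares the predicted values to the true ones.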
1232 class Expr(Task):
1233     def tensorize(self, sequences):
1234         len_max = max([len(x) for x in sequences])
1235         return torch.cat(
1236             [
1237                 torch.tensor(
1238                     [
1239                         [self.char2id[c] for c in s + "#" * (len_max - len(s))]
1240                         for s in sequences
1241                     ]
1242                 )
1243             ],
1244             0,
1245         ).to(self.device)
1246
1247     def __init__(
1248         self,
1249         nb_train_samples,
1250         nb_test_samples,
1251         nb_variables,
1252         sequence_length,
1253         operand_max,
1254         result_max,
1255         batch_size,
1256         device=torch.device("cpu"),
1257     ):
1258         super().__init__()
1259
1260         self.batch_size = batch_size
1261         self.device = device
1262
1263         train_sequences = expr.generate_sequences(
1264             nb_train_samples,
1265             nb_variables=nb_variables,
1266             length=sequence_length,
1267             operand_max=operand_max,
1268             result_max=result_max,
1269         )
1270
1271         test_sequences = expr.generate_sequences(
1272             nb_test_samples,
1273             nb_variables=nb_variables,
1274             length=sequence_length,
1275             operand_max=operand_max,
1276             result_max=result_max,
1277         )
1278
1279         symbols = list(set("#" + "".join(train_sequences + test_sequences)))
1280         symbols.sort()
1281
1282         self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
1283         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
1284
1285         self.filler, self.space = self.char2id["#"], self.char2id[" "]
1286
1287         self.train_input = self.tensorize(train_sequences)
1288         self.test_input = self.tensorize(test_sequences)
1289
1290         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1291
1292     def batches(self, split="train", nb_to_use=-1, desc=None):
1293         assert split in {"train", "test"}
1294         input = self.train_input if split == "train" else self.test_input
1295         if nb_to_use > 0:
1296             input = input[:nb_to_use]
1297         if desc is None:
1298             desc = f"epoch-{split}"
1299         for batch in tqdm.tqdm(
1300             input.split(self.batch_size), dynamic_ncols=True, desc=desc
1301         ):
1302             last = (batch != self.filler).max(0).values.nonzero().max() + 3
1303             batch = batch[:, :last]
1304             yield batch
1305
1306     def vocabulary_size(self):
1307         return self.nb_codes
1308
1309     def seq2str(self, s):
1310         return "".join([self.id2char[k.item()] for k in s])
1311
1312     def produce_results(
1313         self,
1314         n_epoch,
1315         model,
1316         result_dir,
1317         logger,
1318         deterministic_synthesis,
1319         input_file=None,
1320     ):
1321         def compute_nb_correct(input):
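            # Blank out everything after the first space (the values to be
            # computed) and regenerate it; a sequence is correct only if every
            # regenerated character matches the ground truth.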
1322             result = input.clone()
1323             s = (result == self.space).long()
1324             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1325             result = (1 - ar_mask) * result + ar_mask * self.filler
1326             masked_inplace_autoregression(
1327                 model,
1328                 self.batch_size,
1329                 result,
1330                 ar_mask,
1331                 deterministic_synthesis,
1332                 device=self.device,
1333             )
1334
1335             nb_total = input.size(0)
1336             nb_correct = (input == result).long().min(1).values.sum()
1337
1338             #######################################################################
1339             # Compute predicted vs. true variable values
1340
1341             nb_delta = torch.zeros(5, dtype=torch.int64)
1342             nb_missed = 0
1343
1344             values_input = expr.extract_results([self.seq2str(s) for s in input])
1345             values_result = expr.extract_results([self.seq2str(s) for s in result])
1346
1347             filename = os.path.join(result_dir, f"expr_result_{n_epoch:04d}.txt")
1348
1349             with open(filename, "w") as f:
1350                 for i, r in zip(values_input, values_result):
1351                     for n, vi in i.items():
1352                         vr = r.get(n)
1353                         f.write(f"{vi} {-1 if vr is None else vr}\n")
1354
1355                         if vr is None or vr < 0:
1356                             nb_missed += 1
1357                         else:
1358                             d = abs(vr - vi)
1359                             if d >= nb_delta.size(0):
1360                                 nb_missed += 1
1361                             else:
1362                                 nb_delta[d] += 1
1363
1364             ######################################################################
1365
1366             return nb_total, nb_correct, nb_delta, nb_missed
1367
1368         (
1369             test_nb_total,
1370             test_nb_correct,
1371             test_nb_delta,
1372             test_nb_missed,
1373         ) = compute_nb_correct(self.test_input[:10000])
1374
1375         logger(
1376             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1377         )
1378
1379         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
1380
1381         nb_total = test_nb_delta.sum() + test_nb_missed
1382         for d in range(test_nb_delta.size(0)):
1383             logger(
1384                 f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%"
1385             )
1386         logger(
1387             f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%"
1388         )
1389
1390         ##############################################################
1391         # Log a few generated sequences
1392         if input_file is None:
1393             input = self.test_input[:10]
1394         else:
1395             with open(input_file, "r") as f:
1396                 sequences = [e.strip() for e in f.readlines()]
1397                 sequences = [s + " " + "#" * 50 for s in sequences]
1398                 input = self.tensorize(sequences)
1399
1400         result = input.clone()
1401         s = (result == self.space).long()
1402         ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1403         result = (1 - ar_mask) * result + ar_mask * self.filler
1404
1405         for n in range(result.size(0)):
1406             logger(f"test_before {self.seq2str(result[n])}")
1407
1408         masked_inplace_autoregression(
1409             model,
1410             self.batch_size,
1411             result,
1412             ar_mask,
1413             deterministic_synthesis,
1414             device=self.device,
1415         )
1416
1417         correct = (1 - ar_mask) * self.space + ar_mask * input
1418         for n in range(result.size(0)):
1419             comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
1420             logger(f"test_after  {self.seq2str(result[n])} {comment}")
1421             logger(f"truth       {self.seq2str(correct[n])}")
1422         ##############################################################
1423
1424
1425 ######################################################################
1426
1427 import grid
1428
1429
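# Each sample, produced by grid.GridFactory, is a textual description of a
# grid scene together with statements marked "true" or "false". Evaluation
# blanks out the true/false tokens and asks the model to predict them.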
1430 class Grid(Task):
1431     # Make a tensor from a list of strings
1432     def str2tensor(self, descr):
1433         token_descr = [s.strip().split(" ") for s in descr]
1434         l = max([len(s) for s in token_descr])
1435         token_descr = [s + ["#"] * (l - len(s)) for s in token_descr]
1436         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
1437         return torch.tensor(id_descr, device=self.device)
1438
1439     # Make a list of strings from a tensor
1440     def tensor2str(self, x):
1441         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
1442
1443     # Trim the tensor z to remove as many padding tokens as possible on
1444     # the left and right. If z is a tuple, all its elements are trimmed
1445     # according to the trimming computed for the first one.
1446     def trim(self, z, token="#"):
1447         n = self.token2id[token]
1448         if type(z) == tuple:
1449             x = z[0]
1450             i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
1451             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
1452             return tuple([t[:, a:b] for t in z])
1453         else:
1454             i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
1455             a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
1456             return z[:, a:b]
1457
1458     ######################
1459
1460     def __init__(
1461         self,
1462         nb_train_samples,
1463         nb_test_samples,
1464         batch_size,
1465         size,
1466         logger=None,
1467         device=torch.device("cpu"),
1468     ):
1469         super().__init__()
1470
1471         self.device = device
1472         self.batch_size = batch_size
1473         self.grid_factory = grid.GridFactory(size=size)
1474
1475         if logger is not None:
1476             logger(
1477                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
1478             )
1479
1480         self.train_descr = self.grid_factory.generate_samples(
1481             nb_train_samples, lambda r: tqdm.tqdm(r)
1482         )
1483         self.test_descr = self.grid_factory.generate_samples(
1484             nb_test_samples, lambda r: tqdm.tqdm(r)
1485         )
1486
1487         # Build the tokenizer
1488         tokens = set()
1489         for d in [self.train_descr, self.test_descr]:
1490             for s in d:
1491                 for t in s.strip().split(" "):
1492                     tokens.add(t)
1493         # make this set a sorted list to get the same tensors given
1494         # the same descr
1495         tokens = list(tokens)
1496         tokens.sort()
1497         tokens = ["#"] + tokens
1498         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
1499         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
1500         self.t_nul = self.token2id["#"]
1501         self.t_true = self.token2id["true"]
1502         self.t_false = self.token2id["false"]
1503
1504         # Tokenize the train and test sets
1505         self.train_input = self.str2tensor(self.train_descr)
1506         self.test_input = self.str2tensor(self.test_descr)
1507
1508     def batches(self, split="train"):
1509         assert split in {"train", "test"}
1510         input = self.train_input if split == "train" else self.test_input
1511         for batch in tqdm.tqdm(
1512             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
1513         ):
1514             yield self.trim(batch)
1515
1516     def vocabulary_size(self):
1517         return len(self.token2id)
1518
1519     def produce_results(
1520         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1521     ):
1522         correct = self.test_input[:1000]
1523         result = correct.clone()
1524         ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long()
1525         result *= 1 - ar_mask  # paraaaaanoiaaaaaaa
1526
1527         logger(f"----------------------------------------------------------")
1528
1529         for e in self.tensor2str(result[:10]):
1530             logger(f"test_before {e}")
1531
1532         masked_inplace_autoregression(
1533             model,
1534             self.batch_size,
1535             result,
1536             ar_mask,
1537             deterministic_synthesis,
1538             device=self.device,
1539         )
1540
1541         logger(f"----------------------------------------------------------")
1542
1543         for e in self.tensor2str(result[:10]):
1544             logger(f"test_after  {e}")
1545
1546         logger(f"----------------------------------------------------------")
1547
1548         nb_total = ar_mask.sum().item()
1549         nb_correct = ((correct == result).long() * ar_mask).sum().item()
1550
1551         logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
1552         logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
1553
1554
1555 ######################################################################
1556
1557 import qmlp
1558
1559
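# Each sequence, produced by qmlp.generate_sequence_and_test_set, encodes a
# small quantized training set followed by the quantized parameters of an MLP
# trained on it. The model must generate the parameters, which are then scored
# on a held-out test set with qmlp.evaluate_q_params.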
1560 class QMLP(Task):
1561     ######################
1562
1563     def __init__(
1564         self,
1565         nb_train_samples,
1566         nb_test_samples,
1567         batch_size,
1568         result_dir,
1569         logger=None,
1570         device=torch.device("cpu"),
1571     ):
1572         super().__init__()
1573
1574         self.device = device
1575         self.batch_size = batch_size
1576         self.nb_samples_per_mlp = 256
1577
1578         if logger is not None:
1579             logger(
1580                 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
1581             )
1582
1583         seq, q_test_set, test_error = qmlp.generate_sequence_and_test_set(
1584             nb_mlps=nb_train_samples + nb_test_samples,
1585             nb_samples=self.nb_samples_per_mlp,
1586             device=self.device,
1587             batch_size=64,
1588             nb_epochs=250,
1589             nb_mlps_per_batch=1024,
1590         )
1591
1592         self.train_input = seq[:nb_train_samples]
1593         self.train_q_test_set = q_test_set[:nb_train_samples]
1594         self.train_ref_test_errors = test_error[:nb_train_samples]
1595         self.test_input = seq[nb_train_samples:]
1596         self.test_q_test_set = q_test_set[nb_train_samples:]
1597         self.test_ref_test_errors = test_error[nb_train_samples:]
1598
1599         filename = os.path.join(result_dir, f"train_errors_ref.dat")
1600         with open(filename, "w") as f:
1601             for e in self.train_ref_test_errors:
1602                 f.write(f"{e}\n")
1603
1604         filename = os.path.join(result_dir, f"test_errors_ref.dat")
1605         with open(filename, "w") as f:
1606             for e in self.test_ref_test_errors:
1607                 f.write(f"{e}\n")
1608
1609         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
1610
1611     def batches(self, split="train"):
1612         assert split in {"train", "test"}
1613         input = self.train_input if split == "train" else self.test_input
1614         for batch in tqdm.tqdm(
1615             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
1616         ):
1617             yield batch
1618
1619     def vocabulary_size(self):
1620         return self.nb_codes
1621
1622     def produce_results(
1623         self, n_epoch, model, result_dir, logger, deterministic_synthesis
1624     ):
1625         correct = self.test_input[:1000]
1626         result = correct.clone()
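        # Keep the prefix encoding the quantized training set and regenerate
        # the remaining tokens, which encode the MLP parameters.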
1627         ar_mask = (
1628             torch.arange(result.size(1), device=result.device)
1629             > self.nb_samples_per_mlp * 3 + 1
1630         ).long()[None, :]
1631         ar_mask = ar_mask.expand_as(result)
1632         result *= 1 - ar_mask  # paraaaaanoiaaaaaaa
1633
1634         masked_inplace_autoregression(
1635             model,
1636             self.batch_size,
1637             result,
1638             ar_mask,
1639             deterministic_synthesis,
1640             device=self.device,
1641         )
1642
1643         q_train_set = result[:, : self.nb_samples_per_mlp * 3]
1644         q_params = result[:, self.nb_samples_per_mlp * 3 + 1 :]
1645         error_test = qmlp.evaluate_q_params(q_params, self.test_q_test_set)
1646
1647         filename = os.path.join(result_dir, f"test_errors_{n_epoch:04d}.dat")
1648         with open(filename, "w") as f:
1649             for e in error_test:
1650                 f.write(f"{e}\n")
1651
1652
1653 ######################################################################