#!/usr/bin/env python

import math, os, tqdm

import torch, torchvision

from torch import nn
from torch.nn import functional as F

######################################################################

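# Complete, in place, the entries of `input` whose positions are marked
# with a 1 in `ar_mask`, by letting the model generate them
# autoregressively, batch by batch, with the model in eval mode and
# gradients disabled.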
def masked_inplace_autoregression(
    model,
    batch_size,
    input,
    ar_mask,
    deterministic_synthesis,
    forbidden_tokens=None,
    progress_bar_desc="autoregression",
    device=torch.device("cpu"),
):
    assert input.size() == ar_mask.size()

    batches = zip(input.split(batch_size), ar_mask.split(batch_size))

    if progress_bar_desc is not None:
        batches = tqdm.tqdm(
            batches,
            dynamic_ncols=True,
            desc=progress_bar_desc,
            # total=input.size(0) // batch_size,
        )

    with torch.autograd.no_grad():
        t = model.training
        model.eval()

        for input, ar_mask in batches:
            model.masked_inplace_autoregression(
                input, ar_mask, forbidden_tokens, deterministic_synthesis
            )

        model.train(t)


######################################################################

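# Base class of all tasks: a task provides batches of token sequences,
# the vocabulary size, and a produce_results method called at the end
# of each epoch to evaluate a model and write results.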
class Task:
    def batches(self, split="train"):
        pass

    def vocabulary_size(self):
        pass

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        pass


######################################################################

class Problem:
    def generate(self, nb):
        pass

    def perf(self, seq, logger):
        pass


class ProblemByheart(Problem):
    def __init__(self):
        pass

class SandBox(Task):
    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        logger=None,
        device=torch.device("cpu"),
    ):
        super().__init__()

        self.batch_size = batch_size

        def generate_sequences(nb_samples):
            problem_indexes = torch.randint(len(problems), (nb_samples,))
            nb_samples_per_problem = F.one_hot(problem_indexes).sum(0)
            print(f"{nb_samples_per_problem}")

        self.train_input = generate_sequences(nb_train_samples)
        self.test_input = generate_sequences(nb_test_samples)

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        # logger(
        # f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
        # )
        pass


######################################################################

import picoclvr

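# In the PicoCLVR task, each sample is a textual scene description
# ("red above green <sep> ... <img> ...") followed by a tokenized
# image. The model is asked to generate the image part, and the
# picoclvr module checks how many of the requested properties hold.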
class PicoCLVR(Task):
    # Make a tensor from a list of strings
    def tensorize(self, descr):
        token_descr = [s.strip().split(" ") for s in descr]
        l = max([len(s) for s in token_descr])
        token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
        id_descr = [[self.token2id[u] for u in s] for s in token_descr]
        return torch.tensor(id_descr, device=self.device)

    # Make a list of strings from a tensor
    def detensorize(self, x):
        return [" ".join([self.id2token[t.item()] for t in r]) for r in x]

    # Trim the tensor z by removing, on the left and on the right, the
    # columns that contain only the given token. If z is a tuple of
    # tensors, all of them are trimmed according to the first one.
    def trim(self, z, token="<nul>"):
        n = self.token2id[token]
        if type(z) == tuple:
            x = z[0]
            i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
            return tuple([t[:, a:b] for t in z])
        else:
            i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
            return z[:, a:b]

    ######################

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_colors=5,
        logger=None,
        device=torch.device("cpu"),
        pruner_train=None,
        pruner_eval=None,
    ):
        super().__init__()

        def generate_descr(nb, cache_suffix, pruner):
            return picoclvr.generate(
                nb,
                height=self.height,
                width=self.width,
                nb_colors=nb_colors,
                pruner=pruner,
            )

        self.height = height
        self.width = width
        self.batch_size = batch_size
        self.device = device
        self.pruner_train = pruner_train
        self.pruner_eval = pruner_eval

        if logger is not None:
            logger(
                f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
            )

        self.train_descr = generate_descr(
            nb_train_samples, "train", pruner=self.pruner_train
        )
        self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)

        # Build the tokenizer
        tokens = {"<nul>", "<img>"}
        for d in [self.train_descr, self.test_descr]:
            for s in d:
                for t in s.strip().split(" "):
                    tokens.add(t)
        # Sort the tokens so that identical descriptions always produce
        # identical tensors
        tokens = list(tokens)
        tokens.sort()
        self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
        self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
        self.t_img, self.t_nul = self.token2id["<img>"], self.token2id["<nul>"]

        # Tokenize the train and test sets
        self.train_input = self.tensorize(self.train_descr)
        self.test_input = self.tensorize(self.test_descr)

    def batches(self, split="train"):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
        ):
            yield self.trim(batch)

    def vocabulary_size(self):
        return len(self.token2id)
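
    # Replace everything from the <img> token onward with <nul>, let the
    # model regenerate the image part, and count how many of the
    # requested properties are present in the generated scenes.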
    def compute_missing_properties(
        self, n_epoch, model, logger, deterministic_synthesis, pruner=None
    ):
        acc_nb_requested_properties = []
        acc_nb_missing_properties = []
        acc_nb_results = 0

        for input in tqdm.tqdm(
            self.test_input.split(self.batch_size),
            dynamic_ncols=True,
            desc="test-properties",
        ):
            result = input.clone()
            ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1)
            result = (1 - ar_mask) * result + ar_mask * self.t_nul
            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                progress_bar_desc=None,
                device=self.device,
            )

            result_descr = self.detensorize(result)
            np = picoclvr.nb_properties(
                result_descr,
                height=self.height,
                width=self.width,
                pruner=pruner,
            )
            nb_requested_properties, _, nb_missing_properties = zip(*np)
            acc_nb_requested_properties += nb_requested_properties
            acc_nb_missing_properties += nb_missing_properties
            acc_nb_results += len(result_descr)

        nb_requested_properties = sum(acc_nb_requested_properties)
        nb_missing_properties = sum(acc_nb_missing_properties)

        prefix = "" if pruner is None else "pruned_"
        logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
        logger(
            f"property_{prefix}nb {n_epoch} requested {nb_requested_properties} missing {nb_missing_properties}"
        )
        logger(
            f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
        )

    ######################################################################

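    # Evaluate the model on the test set, then generate scenes from a
    # few fixed textual prompts and save them as an image grid.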
    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis)

        if self.pruner_eval is not None:
            self.compute_missing_properties(
                n_epoch, model, logger, deterministic_synthesis, self.pruner_eval
            )

        nb_tokens_to_generate = self.height * self.width + 3
        result_descr = []
        nb_per_primer = 8
        primer = []

        for primer_descr in [
            "red above green <sep> green top <sep> blue right of red",
            "there is red <sep> there is yellow <sep> there is blue",
            "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
            "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
        ]:
            primer += [primer_descr + " <img>"] * nb_per_primer

        result = self.tensorize(primer)
        fill = result.new_full(
            result.size()[:-1] + (self.height * self.width + 1,), self.t_nul
        )
        result = torch.cat((result, fill), 1)
        ar_mask = (result == self.t_nul).long()
        masked_inplace_autoregression(
            model,
            self.batch_size,
            result,
            ar_mask,
            deterministic_synthesis,
            device=self.device,
        )
        result_descr = self.detensorize(result)

        np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)

        acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
        acc_nb_results = len(result_descr)

        nb_requested_properties = sum(acc_nb_requested_properties)
        nb_missing_properties = sum(acc_nb_missing_properties)

        prefix = "demo_"
        logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
        logger(
            f"property_{prefix}nb {n_epoch} requested {nb_requested_properties} missing {nb_missing_properties}"
        )
        logger(
            f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
        )

        img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)

        if img.dim() == 5:
            if img.size(1) == 1:
                img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
            else:
                img = torch.cat(
                    [
                        torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
                        for x in img
                    ],
                    0,
                )

        image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png")
        torchvision.utils.save_image(
            img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
        )
        logger(f"wrote {image_name}")


######################################################################

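# Auto-regressive modelling of raw MNIST digits: each 28x28 image is
# flattened into a sequence of 784 pixel intensities, hence the
# vocabulary of 256 values.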
class MNIST(Task):
    def __init__(
        self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
    ):
        super().__init__()

        self.nb_train_samples = nb_train_samples
        self.nb_test_samples = nb_test_samples
        self.batch_size = batch_size
        self.device = device
        data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True)
        self.train_input = data_set.data[:nb_train_samples].view(-1, 28 * 28).long()
        data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
        self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return 256

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
        ar_mask = torch.full_like(results, 1)
        masked_inplace_autoregression(
            model,
            self.batch_size,
            results,
            ar_mask,
            deterministic_synthesis,
            device=self.device,
        )
        image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
        torchvision.utils.save_image(
            1 - results.reshape(-1, 1, 28, 28) / 255.0,
            image_name,
            nrow=16,
            pad_value=0.8,
        )
        logger(f"wrote {image_name}")


######################################################################

import maze

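# In the Maze task, a sequence is the concatenation of a flattened maze
# map and of the corresponding flattened target path; at evaluation the
# map is given and the model has to generate the path.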
class Maze(Task):
    def map2seq(self, *m):
        return torch.cat([x.flatten(1) for x in m], 1)

    def seq2map(self, s):
        s = s.reshape(s.size(0), -1, self.height, self.width)
        return (s[:, k] for k in range(s.size(1)))

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_walls,
        device=torch.device("cpu"),
    ):
        super().__init__()

        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.device = device

        train_mazes, train_paths, _ = maze.create_maze_data(
            nb_train_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc="data-train"),
        )
        self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))

        test_mazes, test_paths, _ = maze.create_maze_data(
            nb_test_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc="data-test"),
        )
        self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

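    # Regenerate the path part of each sequence and count the predicted
    # paths that are valid; `count[i, j]` accumulates, over the correct
    # predictions, how often an optimal path of length i was answered
    # with a path of length j.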
    def compute_error(
        self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
    ):
        nb_total, nb_correct = 0, 0
        count = torch.zeros(
            self.width * self.height,
            self.width * self.height,
            device=self.device,
            dtype=torch.int64,
        )

        for input in self.batches(split, nb_to_use):
            result = input.clone()
            ar_mask = result.new_zeros(result.size())
            ar_mask[:, self.height * self.width :] = 1
            result *= 1 - ar_mask
            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                progress_bar_desc=None,
                device=self.device,
            )
            mazes, paths = self.seq2map(result)
            path_correctness = maze.path_correctness(mazes, paths)
            nb_correct += path_correctness.long().sum()
            nb_total += mazes.size(0)

            optimal_path_lengths = (
                (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
            )
            predicted_path_lengths = (
                (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
            )
            optimal_path_lengths = optimal_path_lengths[path_correctness]
            predicted_path_lengths = predicted_path_lengths[path_correctness]
            count[optimal_path_lengths, predicted_path_lengths] += 1

        if count.max() == 0:
            count = None
        else:
            count = count[
                : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
            ]

        return nb_total, nb_correct, count

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        train_nb_total, train_nb_correct, count = self.compute_error(
            model,
            "train",
            nb_to_use=1000,
            deterministic_synthesis=deterministic_synthesis,
        )
        logger(
            f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
        )

        test_nb_total, test_nb_correct, count = self.compute_error(
            model,
            "test",
            nb_to_use=1000,
            deterministic_synthesis=deterministic_synthesis,
        )
        logger(
            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
        )

        if count is not None:
            proportion_optimal = count.diagonal().sum().float() / count.sum()
            logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
            with open(
                os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
            ) as f:
                for i in range(count.size(0)):
                    for j in range(count.size(1)):
                        eol = " " if j < count.size(1) - 1 else "\n"
                        f.write(f"{count[i,j]}{eol}")

        input = self.test_input[:48]
        result = input.clone()
        ar_mask = result.new_zeros(result.size())
        ar_mask[:, self.height * self.width :] = 1
        result *= 1 - ar_mask
        masked_inplace_autoregression(
            model,
            self.batch_size,
            result,
            ar_mask,
            deterministic_synthesis,
            device=self.device,
        )

        mazes, paths = self.seq2map(input)
        _, predicted_paths = self.seq2map(result)

        filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png")
        maze.save_image(
            filename,
            mazes=mazes,
            target_paths=paths,
            predicted_paths=predicted_paths,
            path_correct=maze.path_correctness(mazes, predicted_paths),
            path_optimal=maze.path_optimality(paths, predicted_paths),
        )
        logger(f"wrote {filename}")


######################################################################

import snake

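# In the Snake task, sequences interleave two token streams; after the
# prompt (2 * prompt_length tokens), the model has to predict every
# second token, and accuracy is measured only at positions whose grid
# cell was already visited (prior_visits > 0), where the value is
# determined by the past.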
class Snake(Task):
    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_colors,
        length,
        prompt_length,
        device=torch.device("cpu"),
    ):
        super().__init__()

        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.device = device
        self.prompt_length = prompt_length

        self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
            nb_train_samples,
            height,
            width,
            nb_colors,
            length,
            prompt_length,
            self.device,
        )
        self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
            nb_test_samples,
            height,
            width,
            nb_colors,
            length,
            prompt_length,
            self.device,
        )

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        def compute_nb_correct(input, prior_visits):
            result = input.clone()
            i = torch.arange(result.size(1), device=result.device)[None, :]
            ar_mask = (
                torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
                .long()
                .expand_as(result)
            )
            result *= 1 - ar_mask

            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                device=self.device,
            )

            nb_total = ((prior_visits > 0) * ar_mask).sum()

            nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum()

            return nb_total, nb_correct

        test_nb_total, test_nb_correct = compute_nb_correct(
            self.test_input[:1000], self.test_prior_visits[:1000]
        )

        logger(
            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
        )


######################################################################

import stack

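# In the Stack task, sequences encode push/pop operations on several
# stacks; at evaluation, the values produced by the pop operations are
# blanked out with stack.remove_popped_values and must be regenerated
# by the model.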
class Stack(Task):
    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        logger,
        nb_steps,
        nb_stacks,
        nb_digits,
        fraction_values_for_train=None,
        device=torch.device("cpu"),
    ):
        super().__init__()

        self.batch_size = batch_size
        self.nb_steps = nb_steps
        self.nb_stacks = nb_stacks
        self.nb_digits = nb_digits
        self.device = device

        if fraction_values_for_train is None:
            values_for_train = None
            values_for_test = None
        else:
            all = torch.randperm(10**nb_digits)
            nb_for_train = int(all.size(0) * fraction_values_for_train)
            values_for_train = all[:nb_for_train]
            values_for_test = all[nb_for_train:]

        self.train_input, self.train_stack_counts = stack.generate_sequences(
            nb_train_samples,
            nb_steps,
            nb_stacks,
            nb_digits,
            values_for_train,
            self.device,
        )

        self.test_input, self.test_stack_counts = stack.generate_sequences(
            nb_test_samples,
            nb_steps,
            nb_stacks,
            nb_digits,
            values_for_test,
            self.device,
        )

        i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
        counts = self.test_stack_counts.flatten()[i.flatten()]
        counts = F.one_hot(counts).sum(0)
        logger(f"test_pop_stack_counts {counts}")

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        def compute_nb_correct(input):
            result = input.clone()
            stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
            ar_mask = (result != input).long()
            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                device=self.device,
            )

            errors = ((result != input).long() * ar_mask).reshape(
                -1, 1 + self.nb_digits
            )
            ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)

            nb_total = ar_mask.max(1).values.sum()
            nb_correct = nb_total - errors.max(1).values.sum()

            return nb_total, nb_correct

        test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])

        logger(
            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
        )

        ##############################################################
        # Log a few generated sequences
        input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
        result = input.clone()
        stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
        ar_mask = (result != input).long()

        # for n in range(result.size(0)):
        # logger(
        # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
        # )

        masked_inplace_autoregression(
            model,
            self.batch_size,
            result,
            ar_mask,
            deterministic_synthesis,
            device=self.device,
        )

        for n in range(result.size(0)):
            logger(
                f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
            )
        ##############################################################


######################################################################

import expr

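# In the Expr task, a sequence is a character-level arithmetic
# expression followed by a space and its results; the part after the
# first space is masked out and regenerated by the model. Besides
# exact-match accuracy, the numerical error between the predicted and
# the true variable values (via expr.extract_results) is logged.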
class Expr(Task):
    def tensorize(self, sequences):
        len_max = max([len(x) for x in sequences])
        return torch.cat(
            [
                torch.tensor(
                    [
                        [self.char2id[c] for c in s + "#" * (len_max - len(s))]
                        for s in sequences
                    ]
                )
            ],
            0,
        ).to(self.device)

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        nb_variables,
        sequence_length,
        operand_max,
        result_max,
        batch_size,
        device=torch.device("cpu"),
    ):
        super().__init__()

        self.batch_size = batch_size
        self.device = device

        train_sequences = expr.generate_sequences(
            nb_train_samples,
            nb_variables=nb_variables,
            length=sequence_length,
            operand_max=operand_max,
            result_max=result_max,
        )

        test_sequences = expr.generate_sequences(
            nb_test_samples,
            nb_variables=nb_variables,
            length=sequence_length,
            operand_max=operand_max,
            result_max=result_max,
        )

        symbols = list(set("#" + "".join(train_sequences + test_sequences)))
        symbols.sort()

        self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
        self.id2char = dict([(n, c) for c, n in self.char2id.items()])

        self.filler, self.space = self.char2id["#"], self.char2id[" "]

        self.train_input = self.tensorize(train_sequences)
        self.test_input = self.tensorize(test_sequences)

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            last = (batch != self.filler).max(0).values.nonzero().max() + 3
            batch = batch[:, :last]
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def seq2str(self, s):
        return "".join([self.id2char[k.item()] for k in s])

    def produce_results(
        self,
        n_epoch,
        model,
        result_dir,
        logger,
        deterministic_synthesis,
        input_file=None,
    ):
        def compute_nb_correct(input):
            result = input.clone()
            s = (result == self.space).long()
            ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
            result = (1 - ar_mask) * result + ar_mask * self.filler
            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                device=self.device,
            )

            nb_total = input.size(0)
            nb_correct = (input == result).long().min(1).values.sum()

            #######################################################################
            # Compute predicted vs. true variable values

            nb_delta = torch.zeros(5, dtype=torch.int64)
            nb_missed = 0

            values_input = expr.extract_results([self.seq2str(s) for s in input])
            values_result = expr.extract_results([self.seq2str(s) for s in result])

            filename = os.path.join(result_dir, f"expr_result_{n_epoch:04d}.txt")

            with open(filename, "w") as f:
                for i, r in zip(values_input, values_result):
                    for n, vi in i.items():
                        vr = r.get(n)
                        f.write(f"{vi} {-1 if vr is None else vr}\n")

                        if vr is None or vr < 0:
                            nb_missed += 1
                        else:
                            d = abs(vr - vi)
                            if d >= nb_delta.size(0):
                                nb_missed += 1
                            else:
                                nb_delta[d] += 1

            ######################################################################

            return nb_total, nb_correct, nb_delta, nb_missed

        (
            test_nb_total,
            test_nb_correct,
            test_nb_delta,
            test_nb_missed,
        ) = compute_nb_correct(self.test_input[:10000])

        logger(
            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
        )

        nb_total = test_nb_delta.sum() + test_nb_missed
        for d in range(test_nb_delta.size(0)):
            logger(
                f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%"
            )
        logger(
            f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%"
        )

        ##############################################################
        # Log a few generated sequences
        if input_file is None:
            input = self.test_input[:10]
        else:
            with open(input_file, "r") as f:
                sequences = [e.strip() for e in f.readlines()]
                sequences = [s + " " + "#" * 50 for s in sequences]
                input = self.tensorize(sequences)

        result = input.clone()
        s = (result == self.space).long()
        ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
        result = (1 - ar_mask) * result + ar_mask * self.filler

        for n in range(result.size(0)):
            logger(f"test_before {self.seq2str(result[n])}")

        masked_inplace_autoregression(
            model,
            self.batch_size,
            result,
            ar_mask,
            deterministic_synthesis,
            device=self.device,
        )

        correct = (1 - ar_mask) * self.space + ar_mask * input
        for n in range(result.size(0)):
            comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
            logger(f"test_after  {self.seq2str(result[n])} {comment}")
            logger(f"truth       {self.seq2str(correct[n])}")
        ##############################################################


######################################################################

import world

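# In the World task, frames produced by a small simulated world are
# compressed into sequences of discrete codes (presumably by a VQ
# auto-encoder, given vqae_nb_epochs). A training sequence is the
# concatenation of the codes of the first frame, of the action
# sequence, and of the last frame, and the model has to predict the
# last frame given the first frame and the actions.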
class World(Task):
    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        vqae_nb_epochs,
        logger=None,
        device=torch.device("cpu"),
        device_storage=torch.device("cpu"),
    ):
        super().__init__()

        self.batch_size = batch_size
        self.device = device

        (
            train_frames,
            train_action_seq,
            test_frames,
            test_action_seq,
            self.frame2seq,
            self.seq2frame,
        ) = world.create_data_and_processors(
            nb_train_samples,
            nb_test_samples,
            mode="first_last",
            nb_steps=30,
            nb_epochs=vqae_nb_epochs,
            logger=logger,
            device=device,
            device_storage=device_storage,
        )

        print(f"{train_action_seq.size()=}")

        train_frame_seq = self.frame2seq(train_frames).to(device_storage)
        test_frame_seq = self.frame2seq(test_frames).to(device_storage)

        nb_frame_codes = max(train_frame_seq.max(), test_frame_seq.max()) + 1
        nb_action_codes = max(train_action_seq.max(), test_action_seq.max()) + 1

        self.len_frame_seq = train_frame_seq.size(1)
        self.len_action_seq = train_action_seq.size(1)
        self.nb_codes = nb_frame_codes + nb_action_codes

        train_frame_seq = train_frame_seq.reshape(train_frame_seq.size(0) // 2, 2, -1)
        print(f"{train_action_seq.device=} {nb_frame_codes.device=}")
        train_action_seq += nb_frame_codes
        self.train_input = torch.cat(
            (train_frame_seq[:, 0, :], train_action_seq, train_frame_seq[:, 1, :]), 1
        )

        test_frame_seq = test_frame_seq.reshape(test_frame_seq.size(0) // 2, 2, -1)
        test_action_seq += nb_frame_codes
        self.test_input = torch.cat(
            (test_frame_seq[:, 0, :], test_action_seq, test_frame_seq[:, 1, :]), 1
        )

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch.to(self.device)

    def vocabulary_size(self):
        return self.nb_codes

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        k = torch.arange(
            2 * self.len_frame_seq + self.len_action_seq, device=self.device
        )[None, :]

        input = self.test_input[:64].to(self.device)
        result = input.clone()

        ar_mask = (
            (k >= self.len_frame_seq + self.len_action_seq).long().expand_as(result)
        )
        result *= 1 - ar_mask

        masked_inplace_autoregression(
            model,
            self.batch_size,
            result,
            ar_mask,
            deterministic_synthesis,
            device=self.device,
        )

        seq_start = input[:, : self.len_frame_seq]
        seq_end = input[:, self.len_frame_seq + self.len_action_seq :]
        seq_predicted = result[:, self.len_frame_seq + self.len_action_seq :]

        result = torch.cat(
            (seq_start[:, None, :], seq_end[:, None, :], seq_predicted[:, None, :]), 1
        )
        result = result.reshape(-1, result.size(-1))
        print(f"{result.size()=}")

        frames = self.seq2frame(result)
        image_name = os.path.join(result_dir, f"world_result_{n_epoch:04d}.png")
        torchvision.utils.save_image(
            frames.float() / (world.Box.nb_rgb_levels - 1),
            image_name,
            nrow=12,
            padding=1,
            pad_value=0.0,
        )
        logger(f"wrote {image_name}")


######################################################################