#!/usr/bin/env python

import math, os, tqdm

import torch, torchvision

from torch import nn
from torch.nn import functional as F

######################################################################


def masked_inplace_autoregression(
    model,
    batch_size,
    input,
    ar_mask,
    deterministic_synthesis,
    forbidden_tokens=None,
    progress_bar_desc="autoregression",
    device=torch.device("cpu"),
):
    assert input.size() == ar_mask.size()

    batches = zip(input.split(batch_size), ar_mask.split(batch_size))

    if progress_bar_desc is not None:
        batches = tqdm.tqdm(
            batches,
            dynamic_ncols=True,
            desc=progress_bar_desc,
            # total=input.size(0) // batch_size,
        )

    with torch.no_grad():
        t = model.training
        model.eval()

        for input, ar_mask in batches:
            model.masked_inplace_autoregression(
                input, ar_mask, forbidden_tokens, deterministic_synthesis
            )

        model.train(t)

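# Minimal usage sketch for the helper above (illustrative only: the tensors
# below are placeholders, and the model is assumed to expose the
# masked_inplace_autoregression() method with the signature used above):
#
#   input = torch.zeros(4, 16, dtype=torch.int64)   # sequences, prompt in the left half
#   ar_mask = torch.zeros(4, 16, dtype=torch.int64)
#   ar_mask[:, 8:] = 1                               # regenerate only the right half
#   masked_inplace_autoregression(
#       model, batch_size=2, input=input, ar_mask=ar_mask,
#       deterministic_synthesis=False, device=torch.device("cpu"),
#   )
#   # input[:, 8:] now holds the tokens produced by the model
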
######################################################################


class Task:
    def batches(self, split="train"):
        pass

    def vocabulary_size(self):
        pass

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        pass

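# A new task is added by subclassing Task and implementing the three methods
# above. A minimal sketch (hypothetical toy task, not part of the original
# code, kept as a comment so nothing executes):
#
#   class Copy(Task):
#       def __init__(self, nb_samples, batch_size, device=torch.device("cpu")):
#           self.batch_size = batch_size
#           self.train_input = torch.randint(10, (nb_samples, 20), device=device)
#           self.test_input = torch.randint(10, (nb_samples, 20), device=device)
#
#       def batches(self, split="train"):
#           input = self.train_input if split == "train" else self.test_input
#           yield from input.split(self.batch_size)
#
#       def vocabulary_size(self):
#           return 10
#
#       def produce_results(
#           self, n_epoch, model, result_dir, logger, deterministic_synthesis
#       ):
#           pass  # generate samples / compute task-specific metrics here
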
######################################################################

import picoclvr


class PicoCLVR(Task):
    # Make a tensor from a list of strings
    def tensorize(self, descr):
        token_descr = [s.strip().split(" ") for s in descr]
        l = max([len(s) for s in token_descr])
        token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
        id_descr = [[self.token2id[u] for u in s] for s in token_descr]
        return torch.tensor(id_descr, device=self.device)

    # Make a list of strings from a tensor
    def detensorize(self, x):
        return [" ".join([self.id2token[t.item()] for t in r]) for r in x]

    # Trim the tensor z (or the first tensor of the tuple z) by removing as
    # many all-<nul> columns as possible from its left and right ends. If z
    # is a tuple, all its elements are trimmed according to the trimming
    # computed for the first one.
    def trim(self, z, token="<nul>"):
        n = self.token2id[token]
        if type(z) == tuple:
            x = z[0]
            i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
            return tuple([t[:, a:b] for t in z])
        else:
            i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
            return z[:, a:b]

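    # Trimming illustration (made-up ids, with <nul> mapped to 0): for
    #
    #   x = torch.tensor([[0, 0, 3, 5, 0],
    #                     [0, 7, 2, 0, 0]])
    #
    # the first and the last columns contain only <nul>, so trim(x) keeps
    # columns 1 to 3 and returns
    #
    #   [[0, 3, 5],
    #    [7, 2, 0]]
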
    ######################

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_colors=5,
        logger=None,
        device=torch.device("cpu"),
        pruner_train=None,
        pruner_eval=None,
    ):
        def generate_descr(nb, cache_suffix, pruner):
            return picoclvr.generate(
                nb,
                height=self.height,
                width=self.width,
                nb_colors=nb_colors,
                pruner=pruner,
            )

        self.height = height
        self.width = width
        self.batch_size = batch_size
        self.device = device
        self.pruner_train = pruner_train
        self.pruner_eval = pruner_eval

        if logger is not None:
            logger(
                f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
            )

        self.train_descr = generate_descr(
            nb_train_samples, "train", pruner=self.pruner_train
        )
        self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)

        # Build the tokenizer
        tokens = {"<nul>", "<img>"}
        for d in [self.train_descr, self.test_descr]:
            for s in d:
                for t in s.strip().split(" "):
                    tokens.add(t)
        # make this set a sorted list to get the same tensors given
        # the same descr
        tokens = list(tokens)
        tokens.sort()
        self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
        self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
        self.t_img, self.t_nul = self.token2id["<img>"], self.token2id["<nul>"]

        # Tokenize the train and test sets
        self.train_input = self.tensorize(self.train_descr)
        self.test_input = self.tensorize(self.test_descr)

    def batches(self, split="train"):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
        ):
            yield self.trim(batch)

    def vocabulary_size(self):
        return len(self.token2id)

    def compute_missing_properties(
        self, n_epoch, model, logger, deterministic_synthesis, pruner=None
    ):
        acc_nb_requested_properties = []
        acc_nb_missing_properties = []
        acc_nb_results = 0

        for input in tqdm.tqdm(
            self.test_input.split(self.batch_size),
            dynamic_ncols=True,
            desc="test-properties",
        ):
            result = input.clone()
            # Mask every token from the "<img>" token onward and blank it out
            # with "<nul>", so that the model regenerates the image part
            ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1)
            result = (1 - ar_mask) * result + ar_mask * self.t_nul
            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                progress_bar_desc=None,
                device=self.device,
            )

            result_descr = self.detensorize(result)
            np = picoclvr.nb_properties(
                result_descr,
                height=self.height,
                width=self.width,
                pruner=pruner,
            )
            nb_requested_properties, _, nb_missing_properties = zip(*np)
            acc_nb_requested_properties += nb_requested_properties
            acc_nb_missing_properties += nb_missing_properties
            acc_nb_results += len(result_descr)

        nb_requested_properties = sum(acc_nb_requested_properties)
        nb_missing_properties = sum(acc_nb_missing_properties)

        prefix = "" if pruner is None else "pruned_"
        logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
        logger(
            f"property_{prefix}nb {n_epoch} requested {nb_requested_properties} missing {nb_missing_properties}"
        )
        logger(
            f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
        )

    ######################################################################

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis)

        if self.pruner_eval is not None:
            self.compute_missing_properties(
                n_epoch, model, logger, deterministic_synthesis, self.pruner_eval
            )

        nb_tokens_to_generate = self.height * self.width + 3
        result_descr = []
        nb_per_primer = 8
        primer = []

        for primer_descr in [
            "red above green <sep> green top <sep> blue right of red",
            "there is red <sep> there is yellow <sep> there is blue",
            "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
            "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
        ]:
            primer += [primer_descr + " <img>"] * nb_per_primer

        result = self.tensorize(primer)
        fill = result.new_full(
            result.size()[:-1] + (self.height * self.width + 1,), self.t_nul
        )
        result = torch.cat((result, fill), 1)
        ar_mask = (result == self.t_nul).long()
        masked_inplace_autoregression(
            model,
            self.batch_size,
            result,
            ar_mask,
            deterministic_synthesis,
            device=self.device,
        )
        result_descr = self.detensorize(result)

        np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)

        acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
        acc_nb_results = len(result_descr)

        nb_requested_properties = sum(acc_nb_requested_properties)
        nb_missing_properties = sum(acc_nb_missing_properties)

        prefix = "demo_"
        logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
        logger(
            f"property_{prefix}nb {n_epoch} requested {nb_requested_properties} missing {nb_missing_properties}"
        )
        logger(
            f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
        )

        img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)

        if img.dim() == 5:
            if img.size(1) == 1:
                img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
            else:
                img = torch.cat(
                    [
                        torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
                        for x in img
                    ],
                    0,
                )

        image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png")
        torchvision.utils.save_image(
            img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
        )
        logger(f"wrote {image_name}")

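# Construction sketch for this task (all values below are made up for
# illustration; nothing here is executed):
#
#   task = PicoCLVR(
#       nb_train_samples=10000,
#       nb_test_samples=1000,
#       batch_size=25,
#       height=12,
#       width=16,
#       nb_colors=5,
#       logger=print,
#       device=torch.device("cpu"),
#   )
#   for input in task.batches(split="train"):
#       ...  # (batch, sequence_length) LongTensor of token ids, <nul>-trimmed
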
######################################################################


class MNIST(Task):
    def __init__(
        self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
    ):
        self.nb_train_samples = nb_train_samples
        self.nb_test_samples = nb_test_samples
        self.batch_size = batch_size
        self.device = device
        data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True)
        self.train_input = data_set.data[:nb_train_samples].view(-1, 28 * 28).long()
        data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
        self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return 256

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
        ar_mask = torch.full_like(results, 1)
        masked_inplace_autoregression(
            model,
            self.batch_size,
            results,
            ar_mask,
            deterministic_synthesis,
            device=self.device,
        )
        image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
        torchvision.utils.save_image(
            1 - results.reshape(-1, 1, 28, 28) / 255.0,
            image_name,
            nrow=16,
            pad_value=0.8,
        )
        logger(f"wrote {image_name}")

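# Usage sketch (illustrative): every MNIST image is flattened into a
# sequence of 784 pixel intensities in [0, 255], hence the vocabulary of
# 256 tokens, and produce_results() regenerates all 784 positions.
#
#   task = MNIST(nb_train_samples=60000, nb_test_samples=10000, batch_size=64)
#   for input in task.batches(split="train"):
#       ...  # input is a (batch, 784) LongTensor of pixel values
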
######################################################################

import maze


class Maze(Task):
    def map2seq(self, *m):
        return torch.cat([x.flatten(1) for x in m], 1)

    def seq2map(self, s):
        s = s.reshape(s.size(0), -1, self.height, self.width)
        return (s[:, k] for k in range(s.size(1)))

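    # Shape sketch (illustration): with height=13 and width=21,
    # map2seq(mazes, paths) flattens two (N, 13, 21) grids into a single
    # (N, 2 * 13 * 21) sequence, and seq2map() splits such a sequence back
    # into a generator of (N, 13, 21) grids, one per concatenated map.
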
    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_walls,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.device = device

        train_mazes, train_paths, _ = maze.create_maze_data(
            nb_train_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc="data-train"),
        )
        self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))

        test_mazes, test_paths, _ = maze.create_maze_data(
            nb_test_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc="data-test"),
        )
        self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def compute_error(
        self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
    ):
        nb_total, nb_correct = 0, 0
        count = torch.zeros(
            self.width * self.height,
            self.width * self.height,
            device=self.device,
            dtype=torch.int64,
        )

        for input in self.batches(split, nb_to_use):
            result = input.clone()
            ar_mask = result.new_zeros(result.size())
            ar_mask[:, self.height * self.width :] = 1
            result *= 1 - ar_mask
            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                progress_bar_desc=None,
                device=self.device,
            )
            mazes, paths = self.seq2map(result)
            path_correctness = maze.path_correctness(mazes, paths)
            nb_correct += path_correctness.long().sum()
            nb_total += mazes.size(0)

            optimal_path_lengths = (
                (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
            )
            predicted_path_lengths = (
                (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
            )
            optimal_path_lengths = optimal_path_lengths[path_correctness]
            predicted_path_lengths = predicted_path_lengths[path_correctness]
            # index_put_ with accumulate=True so that identical
            # (optimal, predicted) length pairs occurring several times in
            # the same batch are all counted
            count.index_put_(
                (optimal_path_lengths, predicted_path_lengths),
                torch.ones_like(optimal_path_lengths),
                accumulate=True,
            )

        if count.max() == 0:
            count = None
        else:
            count = count[
                : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
            ]

        return nb_total, nb_correct, count

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        train_nb_total, train_nb_correct, count = self.compute_error(
            model,
            "train",
            nb_to_use=1000,
            deterministic_synthesis=deterministic_synthesis,
        )
        logger(
            f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
        )

        test_nb_total, test_nb_correct, count = self.compute_error(
            model,
            "test",
            nb_to_use=1000,
            deterministic_synthesis=deterministic_synthesis,
        )
        logger(
            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
        )

        if count is not None:
            proportion_optimal = count.diagonal().sum().float() / count.sum()
            logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
            with open(
                os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
            ) as f:
                for i in range(count.size(0)):
                    for j in range(count.size(1)):
                        eol = " " if j < count.size(1) - 1 else "\n"
                        f.write(f"{count[i,j]}{eol}")

        input = self.test_input[:48]
        result = input.clone()
        ar_mask = result.new_zeros(result.size())
        ar_mask[:, self.height * self.width :] = 1
        result *= 1 - ar_mask
        masked_inplace_autoregression(
            model,
            self.batch_size,
            result,
            ar_mask,
            deterministic_synthesis,
            device=self.device,
        )

        mazes, paths = self.seq2map(input)
        _, predicted_paths = self.seq2map(result)

        filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png")
        maze.save_image(
            filename,
            mazes=mazes,
            target_paths=paths,
            predicted_paths=predicted_paths,
            path_correct=maze.path_correctness(mazes, predicted_paths),
            path_optimal=maze.path_optimality(paths, predicted_paths),
        )
        logger(f"wrote {filename}")

######################################################################


import snake


class Snake(Task):
    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_colors,
        length,
        prompt_length,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.device = device
        self.prompt_length = prompt_length

        self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
            nb_train_samples,
            height,
            width,
            nb_colors,
            length,
            prompt_length,
            self.device,
        )
        self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
            nb_test_samples,
            height,
            width,
            nb_colors,
            length,
            prompt_length,
            self.device,
        )

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        def compute_nb_correct(input, prior_visits):
            result = input.clone()
            i = torch.arange(result.size(1), device=result.device)[None, :]
            ar_mask = (
                torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
                .long()
                .expand_as(result)
            )
            result *= 1 - ar_mask

            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                device=self.device,
            )

            nb_total = ((prior_visits > 0) * ar_mask).sum()

            nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum()

            return nb_total, nb_correct

        test_nb_total, test_nb_correct = compute_nb_correct(
            self.test_input[:1000], self.test_prior_visits[:1000]
        )

        logger(
            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
        )

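# Masking sketch (illustration): with prompt_length=3, the ar_mask built in
# compute_nb_correct() above is 1 on even positions of index >= 6, e.g. for
# a sequence of length 12:
#
#   i       =  0  1  2  3  4  5  6  7  8  9 10 11
#   ar_mask =  0  0  0  0  0  0  1  0  1  0  1  0
#
# Only those positions are erased and regenerated, and accuracy is measured
# on the regenerated positions whose cell had already been visited
# (prior_visits > 0).
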
######################################################################


import stack


class Stack(Task):
    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        logger,
        nb_steps,
        nb_stacks,
        nb_digits,
        fraction_values_for_train=None,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.nb_steps = nb_steps
        self.nb_stacks = nb_stacks
        self.nb_digits = nb_digits
        self.device = device

        if fraction_values_for_train is None:
            values_for_train = None
            values_for_test = None
        else:
            all = torch.randperm(10**nb_digits)
            nb_for_train = int(all.size(0) * fraction_values_for_train)
            values_for_train = all[:nb_for_train]
            values_for_test = all[nb_for_train:]

        self.train_input, self.train_stack_counts = stack.generate_sequences(
            nb_train_samples,
            nb_steps,
            nb_stacks,
            nb_digits,
            values_for_train,
            self.device,
        )

        self.test_input, self.test_stack_counts = stack.generate_sequences(
            nb_test_samples,
            nb_steps,
            nb_stacks,
            nb_digits,
            values_for_test,
            self.device,
        )

        i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
        counts = self.test_stack_counts.flatten()[i.flatten()]
        counts = F.one_hot(counts).sum(0)
        logger(f"test_pop_stack_counts {counts}")

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        def compute_nb_correct(input):
            result = input.clone()
            stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
            ar_mask = (result != input).long()
            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                device=self.device,
            )

            errors = ((result != input).long() * ar_mask).reshape(
                -1, 1 + self.nb_digits
            )
            ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)

            nb_total = ar_mask.max(1).values.sum()
            nb_correct = nb_total - errors.max(1).values.sum()

            return nb_total, nb_correct

        test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])

        logger(
            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
        )

        ##############################################################
        # Log a few generated sequences
        input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
        result = input.clone()
        stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
        ar_mask = (result != input).long()

        # for n in range(result.size(0)):
        # logger(
        # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
        # )

        masked_inplace_autoregression(
            model,
            self.batch_size,
            result,
            ar_mask,
            deterministic_synthesis,
            device=self.device,
        )

        for n in range(result.size(0)):
            logger(
                f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
            )
        ##############################################################

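# Masking sketch (illustration, relying on the behaviour suggested by the
# name stack.remove_popped_values()): that call blanks out, in the copy
# `result`, the digits that encode the values returned by pop operations;
# the ar_mask is then simply the set of positions where `result` and
# `input` differ, so the model has to regenerate exactly those digits, and
# a pop is counted as correct only if none of its digits is wrong.
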
######################################################################


import expr


class Expr(Task):
    def tensorize(self, sequences):
        len_max = max([len(x) for x in sequences])
        return torch.cat(
            [
                torch.tensor(
                    [
                        [self.char2id[c] for c in s + "#" * (len_max - len(s))]
                        for s in sequences
                    ]
                )
            ],
            0,
        ).to(self.device)

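    # Illustration (made-up sequences): tensorize(["A=1+2", "B=3"]) pads the
    # shorter string with the filler "#" to "B=3##", maps every character to
    # its id through char2id, and returns a (2, 5) LongTensor on self.device.
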
    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        nb_variables,
        sequence_length,
        operand_max,
        result_max,
        batch_size,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.device = device

        train_sequences = expr.generate_sequences(
            nb_train_samples,
            nb_variables=nb_variables,
            length=sequence_length,
            operand_max=operand_max,
            result_max=result_max,
        )

        test_sequences = expr.generate_sequences(
            nb_test_samples,
            nb_variables=nb_variables,
            length=sequence_length,
            operand_max=operand_max,
            result_max=result_max,
        )

        symbols = list(set("#" + "".join(train_sequences + test_sequences)))
        symbols.sort()

        self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
        self.id2char = dict([(n, c) for c, n in self.char2id.items()])

        self.filler, self.space = self.char2id["#"], self.char2id[" "]

        self.train_input = self.tensorize(train_sequences)
        self.test_input = self.tensorize(test_sequences)

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            # Truncate the batch on the right, keeping everything up to the
            # last column that contains a non-filler character (plus a margin)
            last = (batch != self.filler).max(0).values.nonzero().max() + 3
            batch = batch[:, :last]
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def seq2str(self, s):
        return "".join([self.id2char[k.item()] for k in s])

    def produce_results(
        self,
        n_epoch,
        model,
        result_dir,
        logger,
        deterministic_synthesis,
        input_file=None,
    ):
        def compute_nb_correct(input):
            result = input.clone()
            s = (result == self.space).long()
            # Mask every position strictly after the first space (the part
            # of the sequence that holds the values to predict)
            ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
            result = (1 - ar_mask) * result + ar_mask * self.filler
            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                device=self.device,
            )

            nb_total = input.size(0)
            nb_correct = (input == result).long().min(1).values.sum()

            ##################################################################
            # Compute predicted vs. true variable values

            nb_delta = torch.zeros(5, dtype=torch.int64)
            nb_missed = 0

            values_input = expr.extract_results([self.seq2str(s) for s in input])
            values_result = expr.extract_results([self.seq2str(s) for s in result])

            filename = os.path.join(result_dir, f"expr_result_{n_epoch:04d}.txt")

            with open(filename, "w") as f:
                for i, r in zip(values_input, values_result):
                    for n, vi in i.items():
                        vr = r.get(n)
                        f.write(f"{vi} {-1 if vr is None else vr}\n")

                        if vr is None or vr < 0:
                            nb_missed += 1
                        else:
                            d = abs(vr - vi)
                            if d >= nb_delta.size(0):
                                nb_missed += 1
                            else:
                                nb_delta[d] += 1

            ##################################################################

            return nb_total, nb_correct, nb_delta, nb_missed

        (
            test_nb_total,
            test_nb_correct,
            test_nb_delta,
            test_nb_missed,
        ) = compute_nb_correct(self.test_input[:10000])

        logger(
            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
        )

        nb_total = test_nb_delta.sum() + test_nb_missed
        for d in range(test_nb_delta.size(0)):
            logger(
                f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%"
            )
        logger(
            f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%"
        )

        ##############################################################
        # Log a few generated sequences
        if input_file is None:
            input = self.test_input[:10]
        else:
            with open(input_file, "r") as f:
                sequences = [e.strip() for e in f.readlines()]
                sequences = [s + " " + "#" * 50 for s in sequences]
                input = self.tensorize(sequences)

        result = input.clone()
        s = (result == self.space).long()
        ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
        result = (1 - ar_mask) * result + ar_mask * self.filler

        for n in range(result.size(0)):
            logger(f"test_before {self.seq2str(result[n])}")

        masked_inplace_autoregression(
            model,
            self.batch_size,
            result,
            ar_mask,
            deterministic_synthesis,
            device=self.device,
        )

        correct = (1 - ar_mask) * self.space + ar_mask * input
        for n in range(result.size(0)):
            comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
            logger(f"test_after  {self.seq2str(result[n])} {comment}")
            logger(f"truth       {self.seq2str(correct[n])}")
        ##############################################################

######################################################################

import world


class World(Task):
    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        vqae_nb_epochs,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.device = device

        (
            train_frames,
            self.train_actions,
            test_frames,
            self.test_actions,
            self.frame2seq,
            self.seq2frame,
        ) = world.create_data_and_processors(
            nb_train_samples,
            nb_test_samples,
            mode="first_last",
            nb_steps=30,
            nb_epochs=vqae_nb_epochs,
            device=device,
        )

        self.train_input = self.frame2seq(train_frames)
        self.train_input = self.train_input.reshape(self.train_input.size(0) // 2, -1)
        self.test_input = self.frame2seq(test_frames)
        self.test_input = self.test_input.reshape(self.test_input.size(0) // 2, -1)

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        l = self.train_input.size(1)
        k = torch.arange(l, device=self.device)[None, :]
        result = self.test_input[:64].clone()

        ar_mask = (k >= l // 2).long().expand_as(result)
        result *= 1 - ar_mask

        masked_inplace_autoregression(
            model,
            self.batch_size,
            result,
            ar_mask,
            deterministic_synthesis,
            device=self.device,
        )

        result = result.reshape(result.size(0) * 2, -1)

        frames = self.seq2frame(result)
        image_name = os.path.join(result_dir, f"world_result_{n_epoch:04d}.png")
        torchvision.utils.save_image(
            frames.float() / (world.Box.nb_rgb_levels - 1),
            image_name,
            nrow=8,
            padding=1,
            pad_value=0.0,
        )
        logger(f"wrote {image_name}")


######################################################################