Update.
[picoclvr.git] / tasks.py
index 5019aed..c7348d5 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -1419,6 +1419,131 @@ class Expr(Task):
         ##############################################################
 
 
+######################################################################
+
+import grid
+
+
+class Grid(Task):
+    # Make a tensor from a list of strings
+    def tensorize(self, descr):
+        token_descr = [s.strip().split(" ") for s in descr]
+        l = max([len(s) for s in token_descr])
+        token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
+        id_descr = [[self.token2id[u] for u in s] for s in token_descr]
+        return torch.tensor(id_descr, device=self.device)
+
+    # Make a list of strings from a tensor
+    def detensorize(self, x):
+        return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
+
+    # trim all the tensors in the tuple z to remove as much token from
+    # left and right in the first tensor. If z is a tuple, all its
+    # elements are trimed according to the triming for the first
+    def trim(self, z, token="<nul>"):
+        n = self.token2id[token]
+        if type(z) == tuple:
+            x = z[0]
+            i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
+            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+            return tuple([t[:, a:b] for t in z])
+        else:
+            i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
+            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+            return z[:, a:b]
+
+    ######################
+
+    def __init__(
+        self,
+        nb_train_samples,
+        nb_test_samples,
+        batch_size,
+        height,
+        width,
+        logger=None,
+        device=torch.device("cpu"),
+    ):
+        super().__init__()
+
+        self.device = device
+        self.batch_size = batch_size
+        self.grid_factory = grid.GridFactory(height=height, width=width)
+
+        if logger is not None:
+            logger(
+                f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
+            )
+
+        self.train_descr = self.grid_factory.generate_samples(
+            nb_train_samples, lambda r: tqdm.tqdm(r)
+        )
+        self.test_descr = self.grid_factory.generate_samples(
+            nb_test_samples, lambda r: tqdm.tqdm(r)
+        )
+
+        # Build the tokenizer
+        tokens = {}
+        for d in [self.train_descr, self.test_descr]:
+            for s in d:
+                for t in s.strip().split(" "):
+                    tokens.add(t)
+        # make this set a sorted list to get the same tensors given
+        # the same descr
+        tokens = list(tokens)
+        tokens.sort()
+        tokens = ["<nul>"] + tokens
+        self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
+        self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
+        self.t_nul = self.token2id["<nul>"]
+        self.t_true = self.token2id["<true>"]
+        self.t_false = self.token2id["<false>"]
+
+        # Tokenize the train and test sets
+        self.train_input = self.tensorize(self.train_descr)
+        self.test_input = self.tensorize(self.test_descr)
+
+    def batches(self, split="train"):
+        assert split in {"train", "test"}
+        input = self.train_input if split == "train" else self.test_input
+        for batch in tqdm.tqdm(
+            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
+        ):
+            yield self.trim(batch)
+
+    def vocabulary_size(self):
+        return len(self.token2id)
+
+    def produce_results(
+        self, n_epoch, model, result_dir, logger, deterministic_synthesis
+    ):
+        correct = self.test_input[:1000]
+        result = correct.clone()
+        ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long()
+        result *= 1 - ar_mask
+
+        for e in self.detensorize(result[:10]):
+            logger(f"test_before {e}")
+
+        masked_inplace_autoregression(
+            model,
+            self.batch_size,
+            result,
+            ar_mask,
+            deterministic_synthesis,
+            device=self.device,
+        )
+
+        for e in self.detensorize(result[:10]):
+            logger(f"test_after {e}")
+
+        nb_total = ar_mask.sum().item()
+        nb_correct = ((correct == result).long() * ar_mask).sum().item()
+
+        logger(f"test_performance {nb_total=} {nb_correct=}")
+        logger(f"main_test_accuracy {nb_correct / nb_total}")
+
+
 ######################################################################
 
 import world