Update.
[picoclvr.git] / tasks.py
index c7348d5..d787c59 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -1426,21 +1426,21 @@ import grid
 
 class Grid(Task):
     # Make a tensor from a list of strings
-    def tensorize(self, descr):
+    def str2tensor(self, descr):
         token_descr = [s.strip().split(" ") for s in descr]
         l = max([len(s) for s in token_descr])
-        token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
+        token_descr = [s + ["#"] * (l - len(s)) for s in token_descr]
         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
         return torch.tensor(id_descr, device=self.device)
 
     # Make a list of strings from a tensor
-    def detensorize(self, x):
+    def tensor2str(self, x):
         return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
 
     # trim all the tensors in the tuple z to remove as much token from
     # left and right in the first tensor. If z is a tuple, all its
     # elements are trimed according to the triming for the first
-    def trim(self, z, token="<nul>"):
+    def trim(self, z, token="#"):
         n = self.token2id[token]
         if type(z) == tuple:
             x = z[0]
@@ -1459,8 +1459,7 @@ class Grid(Task):
         nb_train_samples,
         nb_test_samples,
         batch_size,
-        height,
-        width,
+        size,
         logger=None,
         device=torch.device("cpu"),
     ):
@@ -1468,7 +1467,7 @@ class Grid(Task):
 
         self.device = device
         self.batch_size = batch_size
-        self.grid_factory = grid.GridFactory(height=height, width=width)
+        self.grid_factory = grid.GridFactory(size=size)
 
         if logger is not None:
             logger(
@@ -1483,7 +1482,7 @@ class Grid(Task):
         )
 
         # Build the tokenizer
-        tokens = {}
+        tokens = set()
         for d in [self.train_descr, self.test_descr]:
             for s in d:
                 for t in s.strip().split(" "):
@@ -1492,16 +1491,16 @@ class Grid(Task):
         # the same descr
         tokens = list(tokens)
         tokens.sort()
-        tokens = ["<nul>"] + tokens
+        tokens = ["#"] + tokens
         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
-        self.t_nul = self.token2id["<nul>"]
-        self.t_true = self.token2id["<true>"]
-        self.t_false = self.token2id["<false>"]
+        self.t_nul = self.token2id["#"]
+        self.t_true = self.token2id["true"]
+        self.t_false = self.token2id["false"]
 
         # Tokenize the train and test sets
-        self.train_input = self.tensorize(self.train_descr)
-        self.test_input = self.tensorize(self.test_descr)
+        self.train_input = self.str2tensor(self.train_descr)
+        self.test_input = self.str2tensor(self.test_descr)
 
     def batches(self, split="train"):
         assert split in {"train", "test"}
@@ -1520,9 +1519,11 @@ class Grid(Task):
         correct = self.test_input[:1000]
         result = correct.clone()
         ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long()
-        result *= 1 - ar_mask
+        result *= 1 - ar_mask  # paraaaaanoiaaaaaaa
 
-        for e in self.detensorize(result[:10]):
+        logger(f"----------------------------------------------------------")
+
+        for e in self.tensor2str(result[:10]):
             logger(f"test_before {e}")
 
         masked_inplace_autoregression(
@@ -1534,14 +1535,18 @@ class Grid(Task):
             device=self.device,
         )
 
-        for e in self.detensorize(result[:10]):
-            logger(f"test_after {e}")
+        logger(f"----------------------------------------------------------")
+
+        for e in self.tensor2str(result[:10]):
+            logger(f"test_after  {e}")
+
+        logger(f"----------------------------------------------------------")
 
         nb_total = ar_mask.sum().item()
         nb_correct = ((correct == result).long() * ar_mask).sum().item()
 
-        logger(f"test_performance {nb_total=} {nb_correct=}")
-        logger(f"main_test_accuracy {nb_correct / nb_total}")
+        logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
+        logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
 
 
 ######################################################################