Update.
[mygptrnn.git] / tasks.py
index afad8af..57c6801 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -106,7 +106,7 @@ class SandBox(Task):
             device
         ), self.test_ar_mask.to(device)
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
         # A bit of paranoia never hurts
         assert self.nb_codes <= max_nb_codes
@@ -250,7 +250,13 @@ class PicoCLVR(Task):
 
     # Make a list of strings from a tensor
     def detensorize(self, x):
-        return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
+        def id2token(t):
+            try:
+                return self.id2token[t.item()]
+            except KeyError:
+                return "?"
+
+        return [" ".join([id2token(t) for t in r]) for r in x]
 
     # trim all the tensors in the tuple z to remove as much token from
     # left and right in the first tensor. If z is a tuple, all its
@@ -573,7 +579,7 @@ class Maze(Task):
         )
         self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -750,7 +756,7 @@ class Snake(Task):
             self.device,
         )
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -865,7 +871,7 @@ class Stack(Task):
         counts = F.one_hot(counts).sum(0)
         logger(f"test_pop_stack_counts {counts}")
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -888,7 +894,10 @@ class Stack(Task):
         def compute_nb_correct(input):
             result = input.clone()
             stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
+
             ar_mask = (result != input).long()
+            result *= 1 - ar_mask
+
             masked_inplace_autoregression(
                 model,
                 self.batch_size,
@@ -923,10 +932,12 @@ class Stack(Task):
         stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
         ar_mask = (result != input).long()
 
-        # for n in range(result.size(0)):
-        # logger(
-        # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
-        # )
+        for n in range(result.size(0)):
+            logger(
+                f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
+            )
+
+        result *= 1 - ar_mask
 
         masked_inplace_autoregression(
             model,
@@ -1067,7 +1078,7 @@ class RPL(Task):
                 s = " ".join(seq)
                 logger(f"example_seq {s}")
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -1297,7 +1308,7 @@ class Expr(Task):
         self.train_input = self.tensorize(train_sequences)
         self.test_input = self.tensorize(test_sequences)
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -1448,7 +1459,13 @@ class Grid(Task):
 
     # Make a list of strings from a tensor
     def tensor2str(self, x):
-        return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
+        def id2token(t):
+            try:
+                return self.id2token[t.item()]
+            except KeyError:
+                return "?"
+
+        return [" ".join([id2token(t) for t in r]) for r in x]
 
     # trim all the tensors in the tuple z to remove as much token from
     # left and right in the first tensor. If z is a tuple, all its
@@ -1473,6 +1490,8 @@ class Grid(Task):
         nb_test_samples,
         batch_size,
         size,
+        nb_shapes,
+        nb_colors,
         logger=None,
         device=torch.device("cpu"),
     ):
@@ -1480,7 +1499,9 @@ class Grid(Task):
 
         self.device = device
         self.batch_size = batch_size
-        self.grid_factory = grid.GridFactory(size=size)
+        self.grid_factory = grid.GridFactory(
+            size=size, nb_shapes=nb_shapes, nb_colors=nb_colors
+        )
 
         if logger is not None:
             logger(
@@ -1515,11 +1536,13 @@ class Grid(Task):
         self.train_input = self.str2tensor(self.train_descr)
         self.test_input = self.str2tensor(self.test_descr)
 
-    def batches(self, split="train"):
+    def batches(self, split="train", desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
+        if desc is None:
+            desc = f"epoch-{split}"
         for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
         ):
             yield self.trim(batch)
 
@@ -1616,13 +1639,15 @@ class QMLP(Task):
             for e in self.test_ref_test_errors:
                 f.write(f"{e}\n")
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
-    def batches(self, split="train"):
+    def batches(self, split="train", desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
+        if desc is None:
+            desc = f"epoch-{split}"
         for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
         ):
             yield batch