Update.
[mygptrnn.git] / tasks.py
index 727b196..218ff36 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -250,7 +250,13 @@ class PicoCLVR(Task):
 
     # Make a list of strings from a tensor
     def detensorize(self, x):
-        return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
+        def id2token(t):
+            try:
+                return self.id2token[t.item()]
+            except KeyError:
+                return "?"
+
+        return [" ".join([id2token(t) for t in r]) for r in x]
 
     # trim all the tensors in the tuple z to remove as much token from
     # left and right in the first tensor. If z is a tuple, all its
@@ -888,7 +894,10 @@ class Stack(Task):
         def compute_nb_correct(input):
             result = input.clone()
             stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
+
             ar_mask = (result != input).long()
+            result *= 1 - ar_mask
+
             masked_inplace_autoregression(
                 model,
                 self.batch_size,
@@ -923,10 +932,12 @@ class Stack(Task):
         stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
         ar_mask = (result != input).long()
 
-        # for n in range(result.size(0)):
-        # logger(
-        # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
-        # )
+        for n in range(result.size(0)):
+            logger(
+                f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
+            )
+
+        result *= 1 - ar_mask
 
         masked_inplace_autoregression(
             model,
@@ -1448,7 +1459,13 @@ class Grid(Task):
 
     # Make a list of strings from a tensor
     def tensor2str(self, x):
-        return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
+        def id2token(t):
+            try:
+                return self.id2token[t.item()]
+            except KeyError:
+                return "?"
+
+        return [" ".join([id2token(t) for t in r]) for r in x]
 
     # trim all the tensors in the tuple z to remove as much token from
     # left and right in the first tensor. If z is a tuple, all its