return sequences, ar_mask
def seq2str(self, seq):
- return "".join(self.token_string[x.item()] for x in seq)
+ def decode(x):
+ if x < len(self.token_string):
+ return self.token_string[x]
+ else:
+ return "?"
+
+ return "".join(decode(x.item()) for x in seq)
class ProblemTwoTargets(Problem):