Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 18 Jul 2023 20:03:24 +0000 (22:03 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 18 Jul 2023 20:03:24 +0000 (22:03 +0200)
tasks.py

index 332d6c5..5ac78cb 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -101,7 +101,7 @@ class ProblemLevel1(Problem):
         b = a + 1 + self.len_prompt
         sequences = torch.empty(nb, b + 1 + self.len_result, dtype=torch.int64)
         nb_operators = torch.randint(self.operators.size(0), (nb,))
-        sequences[:, :a] = (nb_operators[:, None] / 10 ** torch.arange(a)) % 10
+        sequences[:, :a] = (nb_operators[:, None] / 10 ** torch.arange(a-1,-1,-1)) % 10
         sequences[:, a] = 10
         sequences[:, a + 1 : b] = torch.randint(10, (nb, b - a - 1))
         sequences[:, b] = 11
@@ -115,7 +115,7 @@ class ProblemLevel1(Problem):
         return sequences, ar_mask
 
     def seq2str(self, seq):
-        return "".join(self.id2char[x.item()] for x in seq)
+        return "".join("0123456789|>"[x.item()] for x in seq)
 
 
 ####################