Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 4 Jul 2023 16:08:55 +0000 (18:08 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 4 Jul 2023 16:08:55 +0000 (18:08 +0200)
main.py

diff --git a/main.py b/main.py
index beafc19..b907e60 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -1091,7 +1091,7 @@ class TaskExpr(Task):
                 result = input.clone()
                 filler, space = self.char2id["#"], self.char2id[" "]
                 ar_mask = (result == space).long().cumsum(dim=1).clamp(max=1)
-                result = (1 - ar_mask) * result + filler * ar_mask
+                result = (1 - ar_mask) * result + ar_mask * filler
                 masked_inplace_autoregression(
                     model, self.batch_size, result, ar_mask, device=self.device
                 )
@@ -1113,16 +1113,19 @@ class TaskExpr(Task):
             result = input.clone()
             filler, space = self.char2id["#"], self.char2id[" "]
             ar_mask = (result == space).long().cumsum(dim=1).clamp(max=1)
-            result = (1 - ar_mask) * result + filler * ar_mask
+            result = (1 - ar_mask) * result + ar_mask * filler
             for n in range(result.size(0)):
                 s = "".join([self.id2char[k.item()] for k in result[n]])
                 log_string(f"test_before {s}")
             masked_inplace_autoregression(
                 model, self.batch_size, result, ar_mask, device=self.device
             )
+            correct = (1 - ar_mask) * space + ar_mask * input
             for n in range(result.size(0)):
                 s = "".join([self.id2char[k.item()] for k in result[n]])
                 log_string(f"test_after  {s}")
+                s = "".join([self.id2char[k.item()] for k in correct[n]])
+                log_string(f"correct     {s}")
             ##############################################################
 
             model.train(t)