Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 8 Jul 2023 09:17:25 +0000 (11:17 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 8 Jul 2023 09:17:25 +0000 (11:17 +0200)
tasks.py

index 912b405..8fe89be 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -748,18 +748,21 @@ class Stack(Task):
             result = input.clone()
             stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
             ar_mask = (result != input).long()
-            for n in range(result.size(0)):
-                logger(
-                    f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
-                )
-                masked_inplace_autoregression(
-                    model,
-                    self.batch_size,
-                    result,
-                    ar_mask,
-                    deterministic_synthesis,
-                    device=self.device,
-                )
+
+            # for n in range(result.size(0)):
+            # logger(
+            # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
+            # )
+
+            masked_inplace_autoregression(
+                model,
+                self.batch_size,
+                result,
+                ar_mask,
+                deterministic_synthesis,
+                device=self.device,
+            )
+
             for n in range(result.size(0)):
                 logger(
                     f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
@@ -936,16 +939,19 @@ class Expr(Task):
             result = input.clone()
             ar_mask = (result == self.space).long().cumsum(dim=1).clamp(max=1)
             result = (1 - ar_mask) * result + ar_mask * self.filler
-            for n in range(result.size(0)):
-                logger(f"test_before {self.seq2str(result[n])}")
-                masked_inplace_autoregression(
-                    model,
-                    self.batch_size,
-                    result,
-                    ar_mask,
-                    deterministic_synthesis,
-                    device=self.device,
-                )
+
+            # for n in range(result.size(0)):
+            # logger(f"test_before {self.seq2str(result[n])}")
+
+            masked_inplace_autoregression(
+                model,
+                self.batch_size,
+                result,
+                ar_mask,
+                deterministic_synthesis,
+                device=self.device,
+            )
+
             correct = (1 - ar_mask) * self.space + ar_mask * input
             for n in range(result.size(0)):
                 comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""