Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 24 Mar 2024 15:17:47 +0000 (16:17 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 24 Mar 2024 15:17:47 +0000 (16:17 +0100)
tasks.py

index 02c44bb..a4ef557 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -1914,7 +1914,8 @@ class Escape(Task):
     ):
         result = self.test_input[:100].clone()
         t = torch.arange(result.size(1), device=result.device)
-        itl = self.height * self.width + 3
+        state_len = self.height * self.width
+        iteration_len = state_len + 3
 
         def ar():
             masked_inplace_autoregression(
@@ -1926,18 +1927,20 @@ class Escape(Task):
                 device=self.device,
             )
 
-        for u in range(itl, result.size(1) - itl + 1, itl):
-            print(f"{itl=} {u=} {result.size(1)=}")
+        for u in range(
+            iteration_len, result.size(1) - iteration_len + 1, iteration_len
+        ):
+            # Put a lookahead reward to -1, sample the next state
             result[:, u - 1] = (-1) + 1 + escape.first_lookahead_rewards_code
-            ar_mask = (t >= u).long() * (t < u + self.height * self.width).long()
+            ar_mask = (t >= u).long() * (t < u + state_len).long()
             ar_mask = ar_mask[None, :]
             ar_mask = ar_mask.expand_as(result)
             result *= 1 - ar_mask
             ar()
+
+            # Put a lookahead reward to +1, sample the action and reward
             result[:, u - 1] = (1) + 1 + escape.first_lookahead_rewards_code
-            ar_mask = (t >= self.height * self.width).long() * (
-                t < self.height * self.width + 2
-            ).long()
+            ar_mask = (t >= state_len).long() * (t < state_len + 2).long()
             ar_mask = ar_mask[None, :]
             ar_mask = ar_mask.expand_as(result)
             result *= 1 - ar_mask