+ t = torch.arange(input.size(1), device=input.device)[None, :]  # indices 0 .. input.size(1)-1 as a single row, shape (1, input.size(1))
+ u = torch.randint(input.size(1), (input.size(0), 1), device=input.device)  # one random cutoff in [0, input.size(1)) per row, shape (input.size(0), 1)
+ lr_mask = (t <= u).long() * (  # 1 where the index is at or before its row's cutoff AND ...
+ t % self.it_len == self.index_lookahead_reward  # ... the index modulo self.it_len equals self.index_lookahead_reward
+ ).long()  # 0 elsewhere; t and u broadcast to (input.size(0), input.size(1))
+
+ input = lr_mask * escape.lookahead_reward2code(2) + (1 - lr_mask) * input  # masked entries take lookahead_reward2code(2), the rest keep their value; NOTE(review): assumes input is 2-D, and the meaning of the argument 2 is not visible here - confirm