+ hit = (
+ (agent[:, t + 1, 1:, :] * monster[:, t + 1, :-1, :]).flatten(1).sum(dim=1)
+ + (agent[:, t + 1, :-1, :] * monster[:, t + 1, 1:, :]).flatten(1).sum(dim=1)
+ + (agent[:, t + 1, :, 1:] * monster[:, t + 1, :, :-1]).flatten(1).sum(dim=1)
+ + (agent[:, t + 1, :, :-1] * monster[:, t + 1, :, 1:]).flatten(1).sum(dim=1)
+ )
+ hit = (hit > 0).long()
+
+ assert hit.min() == 0 and hit.max() <= 1
+
+ rewards[:, t] = -hit + (1 - hit) * agent[:, t + 1, -1, -1]
+