From f6651daffafd4f1f8eb51e4b6feb9c3e93c5c860 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 24 Mar 2024 08:15:46 +0100 Subject: [PATCH] Update. --- evasion.py | 58 +++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 40 insertions(+), 18 deletions(-) diff --git a/evasion.py b/evasion.py index 4efa4b3..a06213e 100755 --- a/evasion.py +++ b/evasion.py @@ -27,6 +27,8 @@ def generate_sequence(nb, height=6, width=6, T=10): agent = torch.zeros(seq.size(), dtype=torch.int64) agent[:, 0, 0, 0] = 1 agent_actions = torch.randint(5, (nb, T)) + rewards = torch.zeros(nb, T, dtype=torch.int64) + monster = torch.zeros(seq.size(), dtype=torch.int64) monster[:, 0, -1, -1] = 1 monster_actions = torch.randint(5, (nb, T)) @@ -65,44 +67,64 @@ def generate_sequence(nb, height=6, width=6, T=10): ).long() monster[:, t + 1] = collision * monster[:, t] + (1 - collision) * after_move + hit = ( + (agent[:, t + 1, 1:, :] * monster[:, t + 1, :-1, :]).flatten(1).sum(dim=1) + + (agent[:, t + 1, :-1, :] * monster[:, t + 1, 1:, :]).flatten(1).sum(dim=1) + + (agent[:, t + 1, :, 1:] * monster[:, t + 1, :, :-1]).flatten(1).sum(dim=1) + + (agent[:, t + 1, :, :-1] * monster[:, t + 1, :, 1:]).flatten(1).sum(dim=1) + ) + hit = (hit > 0).long() + + assert hit.min() == 0 and hit.max() <= 1 + + rewards[:, t] = -hit + (1 - hit) * agent[:, t + 1, -1, -1] + seq += 2 * agent + 3 * monster - return seq, agent_actions + return seq, agent_actions, rewards ###################################################################### -def seq2str(seq, actions=None): +def seq2str(seq, actions, rewards): # symbols=" #@$" + # vert, hori, cross, thin_hori = "|", "-", "+", "-" + symbols = " █@$" + vert, hori, cross, thin_hori = "║", "═", "╬", "─" + vert, hori, cross, thin_hori = "┃", "━", "╋", "─" - hline = ("+" + "-" * seq.size(-1)) * seq.size(1) + "+" + "\n" + # hline = ("+" + "-" * seq.size(-1)) * seq.size(1) + "+" + "\n" + hline = (cross + hori * seq.size(-1)) * seq.size(1) + cross + "\n" result = hline for n in range(seq.size(0)): for i in range(seq.size(2)): result += ( - "|" - + "|".join( + vert + + vert.join( ["".join([symbols[v.item()] for v in row]) for row in seq[n, :, i]] ) - + "|" + + vert + "\n" ) - result += hline + # result += hline + result += (vert + thin_hori * seq.size(-1)) * seq.size(1) + vert + "\n" - if actions is not None: - result += ( - "|" - + "|".join( - ["INESW"[a.item()] + " " * (seq.size(-1) - 1) for a in actions[n]] - ) - + "|" - + "\n" - ) + def status_bar(a, r): + a = "INESW"[a.item()] + r = f"{r.item()}" + return a + " " * (seq.size(-1) - len(a) - len(r)) + r + + result += ( + vert + + vert.join([status_bar(a, r) for a, r in zip(actions[n], rewards[n])]) + + vert + + "\n" + ) result += hline @@ -112,6 +134,6 @@ def seq2str(seq, actions=None): ###################################################################### if __name__ == "__main__": - seq, actions = generate_sequence(40, 4, 6, T=20) + seq, actions, rewards = generate_sequence(10, 4, 6, T=20) - print(seq2str(seq, actions)) + print(seq2str(seq, actions, rewards)) -- 2.20.1