From 5cb7fa9e00bdcf59a4a50bd7deefec416e87fe43 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 26 Mar 2024 13:00:24 +0100 Subject: [PATCH] Update. --- greed.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/greed.py b/greed.py index f7b4cf7..6b271b5 100755 --- a/greed.py +++ b/greed.py @@ -94,9 +94,9 @@ def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3, nb_coins=2): agent_actions = torch.randint(5, (nb, T)) rewards = torch.zeros(nb, T, dtype=torch.int64) - monster = torch.zeros(states.size(), dtype=torch.int64) - monster[:, 0, -1, -1] = 1 - monster_actions = torch.randint(5, (nb, T)) + troll = torch.zeros(states.size(), dtype=torch.int64) + troll[:, 0, -1, -1] = 1 + troll_actions = torch.randint(5, (nb, T)) all_moves = agent.new(nb, 5, height, width) for t in range(T - 1): @@ -109,7 +109,7 @@ def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3, nb_coins=2): a = F.one_hot(agent_actions[:, t], num_classes=5)[:, :, None, None] after_move = (all_moves * a).sum(dim=1) collision = ( - (after_move * (1 - wall) * (1 - monster[:, t])) + (after_move * (1 - wall) * (1 - troll[:, t])) .flatten(1) .sum(dim=1)[:, None, None] == 0 @@ -117,12 +117,12 @@ def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3, nb_coins=2): agent[:, t + 1] = collision * agent[:, t] + (1 - collision) * after_move all_moves.zero_() - all_moves[:, 0] = monster[:, t] - all_moves[:, 1, 1:, :] = monster[:, t, :-1, :] - all_moves[:, 2, :-1, :] = monster[:, t, 1:, :] - all_moves[:, 3, :, 1:] = monster[:, t, :, :-1] - all_moves[:, 4, :, :-1] = monster[:, t, :, 1:] - a = F.one_hot(monster_actions[:, t], num_classes=5)[:, :, None, None] + all_moves[:, 0] = troll[:, t] + all_moves[:, 1, 1:, :] = troll[:, t, :-1, :] + all_moves[:, 2, :-1, :] = troll[:, t, 1:, :] + all_moves[:, 3, :, 1:] = troll[:, t, :, :-1] + all_moves[:, 4, :, :-1] = troll[:, t, :, 1:] + a = F.one_hot(troll_actions[:, t], num_classes=5)[:, :, None, None] after_move = (all_moves * a).sum(dim=1) collision = ( (after_move * (1 - wall) * (1 - agent[:, t + 1])) @@ -130,13 +130,13 @@ def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3, nb_coins=2): .sum(dim=1)[:, None, None] == 0 ).long() - monster[:, t + 1] = collision * monster[:, t] + (1 - collision) * after_move + troll[:, t + 1] = collision * troll[:, t] + (1 - collision) * after_move hit = ( - (agent[:, t + 1, 1:, :] * monster[:, t + 1, :-1, :]).flatten(1).sum(dim=1) - + (agent[:, t + 1, :-1, :] * monster[:, t + 1, 1:, :]).flatten(1).sum(dim=1) - + (agent[:, t + 1, :, 1:] * monster[:, t + 1, :, :-1]).flatten(1).sum(dim=1) - + (agent[:, t + 1, :, :-1] * monster[:, t + 1, :, 1:]).flatten(1).sum(dim=1) + (agent[:, t + 1, 1:, :] * troll[:, t + 1, :-1, :]).flatten(1).sum(dim=1) + + (agent[:, t + 1, :-1, :] * troll[:, t + 1, 1:, :]).flatten(1).sum(dim=1) + + (agent[:, t + 1, :, 1:] * troll[:, t + 1, :, :-1]).flatten(1).sum(dim=1) + + (agent[:, t + 1, :, :-1] * troll[:, t + 1, :, 1:]).flatten(1).sum(dim=1) ) hit = (hit > 0).long() @@ -147,7 +147,7 @@ def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3, nb_coins=2): rewards[:, t + 1] = -hit + (1 - hit) * got_coin - states = states + 2 * agent + 3 * monster + 4 * coins + states = states + 2 * agent + 3 * troll + 4 * coins * (1 - troll) return states, agent_actions, rewards @@ -271,7 +271,7 @@ def episodes2str( result += hline if ansi_colors: - for u, c in [("$", 31), ("@", 32)]: + for u, c in [("T", 31), ("@", 32), ("$", 34)]: result = result.replace(u, f"\u001b[{c}m{u}\u001b[0m") return result -- 2.20.1