5 from torch.nn import functional as F
7 ######################################################################
# Build `nb` random sequences of `T` grid frames (height x width) in which an
# agent and a monster each follow one randomly drawn action per time step.
# Returns (seq, agent_actions); the monster's actions are not returned.
# NOTE(review): this view of the function has gaps (the embedded listing
# numbers skip), so `wall` and `collision` are used below without a visible
# assignment -- check the full file before changing any logic here.
10 def generate_sequence(nb, height=6, width=6, T=10):
# One uniform random score per cell of every grid.
11 rnd = torch.rand(nb, height, width)
# Tail of an expression whose opening line is not shown: per sample, a one-hot
# mask of the cell holding the largest score in `rnd`, cast to int64 and
# reshaped back to rnd's shape. Presumably accumulated into `wall` -- TODO confirm.
20 rnd.flatten(1).argmax(dim=1)[:, None]
21 == torch.arange(rnd.flatten(1).size(1))[None, :]
22 ).long().reshape(rnd.size())
# Zero the score of every cell where `wall` is >= 1.
23 rnd = rnd * (1 - wall.clamp(max=1))
# Replicate `wall` along a new dim 1 of length T; clone() turns the expanded
# view into real storage so it can be updated in place at the end.
25 seq = wall[:, None, :, :].expand(-1, T, -1, -1).clone()
# Occupancy maps with the same shape as `seq`.
# NOTE(review): no visible line sets the agent's position at t = 0 -- a line
# appears to be missing from this view.
27 agent = torch.zeros(seq.size(), dtype=torch.int64)
# One action index in [0, 5) per sample and time step.
29 agent_actions = torch.randint(5, (nb, T))
30 monster = torch.zeros(seq.size(), dtype=torch.int64)
# At t = 0 the monster occupies the last row, last column.
31 monster[:, 0, -1, -1] = 1
32 monster_actions = torch.randint(5, (nb, T))
# Scratch buffer with one channel per action (5), same dtype as `agent`.
# NOTE(review): Tensor.new() does not initialise its memory, and the visible
# assignments never write the border slices of channels 1-4; an omitted line
# presumably clears the buffer each step -- TODO confirm.
34 all_moves = agent.new(nb, 5, height, width)
# Each iteration derives frame t + 1 from frame t.
35 for t in range(T - 1):
# Candidate positions for the agent, one channel per action:
#   0: unchanged, 1: row index + 1, 2: row index - 1,
#   3: column index + 1, 4: column index - 1.
37 all_moves[:, 0] = agent[:, t]
38 all_moves[:, 1, 1:, :] = agent[:, t, :-1, :]
39 all_moves[:, 2, :-1, :] = agent[:, t, 1:, :]
40 all_moves[:, 3, :, 1:] = agent[:, t, :, :-1]
41 all_moves[:, 4, :, :-1] = agent[:, t, :, 1:]
# One-hot of the drawn action, broadcast over the spatial dims, then used to
# pick the matching candidate channel.
42 a = F.one_hot(agent_actions[:, t], num_classes=5)[:, :, None, None]
43 after_move = (all_moves * a).sum(dim=1)
# Fragment of an expression whose surrounding lines are not shown (presumably
# the definition of `collision`): masks out cells where `wall` is 1 or where
# the monster sits at time t, then reduces per sample.
45 (after_move * (1 - wall) * (1 - monster[:, t]))
47 .sum(dim=1)[:, None, None]
# Where `collision` is 1 the agent keeps its position from t; where it is 0
# the moved position is used.
50 agent[:, t + 1] = collision * agent[:, t] + (1 - collision) * after_move
# Same procedure for the monster; the buffer is reused.
53 all_moves[:, 0] = monster[:, t]
54 all_moves[:, 1, 1:, :] = monster[:, t, :-1, :]
55 all_moves[:, 2, :-1, :] = monster[:, t, 1:, :]
56 all_moves[:, 3, :, 1:] = monster[:, t, :, :-1]
57 all_moves[:, 4, :, :-1] = monster[:, t, :, 1:]
58 a = F.one_hot(monster_actions[:, t], num_classes=5)[:, :, None, None]
59 after_move = (all_moves * a).sum(dim=1)
# Unlike the agent's check, this one masks against the agent's already-updated
# position at t + 1 rather than its position at t.
61 (after_move * (1 - wall) * (1 - agent[:, t + 1]))
63 .sum(dim=1)[:, None, None]
66 monster[:, t + 1] = collision * monster[:, t] + (1 - collision) * after_move
# Encode both entities into the frames in place: +2 where `agent` is 1 and
# +3 where `monster` is 1, on top of the replicated `wall` values.
68 seq += 2 * agent + 3 * monster
70 return seq, agent_actions
73 ######################################################################
# Render a batch of frame sequences as text, optionally with a line of action
# letters per sequence.
# NOTE(review): only part of this function is visible (the embedded listing
# numbers skip); `symbols`, the string assembly and the return statement are
# not shown here.
76 def seq2str(seq, actions=None):
# Horizontal rule: for each index along dim 1 of `seq`, a "+" followed by
# seq.size(-1) dashes; closed by a final "+" and a newline.
80 hline = ("+" + "-" * seq.size(-1)) * seq.size(1) + "+" + "\n"
# Outer loop over dim 0 of `seq`, inner loop over dim 2.
# Assumes seq is (nb, T, height, width) as built by generate_sequence -- TODO confirm.
84 for n in range(seq.size(0)):
85 for i in range(seq.size(2)):
# For index i along dim 2, one string per index along dim 1: every cell value
# is looked up in `symbols` (defined outside the visible lines).
89 ["".join([symbols[v.item()] for v in row]) for row in seq[n, :, i]]
97 if actions is not None:
# One entry per action of sequence n: the letter "INESW"[action] padded with
# spaces to the grid width so it lines up with that step's frame.
101 ["INESW"[a.item()] + " " * (seq.size(-1) - 1) for a in actions[n]]
112 ######################################################################
# Demo when run as a script: generate 40 sequences on a 4 x 6 grid with
# T = 20 and print their text rendering together with the agent's actions.
114 if __name__ == "__main__":
115 seq, actions = generate_sequence(40, 4, 6, T=20)
117 print(seq2str(seq, actions))