3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
10 from torch.nn import functional as F
12 ######################################################################
nb_lookahead_rewards_codes = 4  # stands for -1, 0, +1, and UNKNOWN

# Token vocabulary layout: codes are allocated in consecutive ranges, one
# range per token family (states, then actions, then rewards, then
# lookahead rewards).  `first_states_code` and the `nb_*_codes` counts used
# below are defined on lines not visible in this excerpt.
first_actions_code = first_states_code + nb_states_codes
first_rewards_code = first_actions_code + nb_actions_codes
first_lookahead_rewards_code = first_rewards_code + nb_rewards_codes
nb_codes = first_lookahead_rewards_code + nb_lookahead_rewards_codes  # total vocabulary size
25 ######################################################################
# NOTE(review): the six `return` lines below are the bodies of the code
# mapping helpers whose `def` lines are not visible in this excerpt.
# Judging by the call sites in episodes2seq/seq2episodes they are, in
# order: state2code, code2state, action2code, code2action, reward2code,
# code2reward — confirm against the full file.
    return r + first_states_code

    return r - first_states_code

    return r + first_actions_code

    return r - first_actions_code

    # rewards -1/0/+1 are shifted by +1 into the non-negative code range
    return r + 1 + first_rewards_code

    return r - first_rewards_code - 1
def lookahead_reward2code(r):
    """Encode a lookahead reward (-1, 0, +1, or 2 for UNKNOWN) as a token code.

    The value is shifted by +1 so the smallest reward lands on the first
    code of the lookahead-reward range.
    """
    return first_lookahead_rewards_code + (r + 1)
def code2lookahead_reward(r):
    """Decode a lookahead-reward token code back to -1, 0, +1, or 2 (UNKNOWN)."""
    offset = first_lookahead_rewards_code + 1
    return r - offset
61 ######################################################################
def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
    # Generate `nb` random episodes of T steps on a height x width grid:
    # walls and coins are sampled first, then an agent and a troll (started
    # at the bottom-right corner, see `troll[:, 0, -1, -1] = 1`) both take
    # uniformly random actions.  Returns (states, agent_actions, rewards).
    # NOTE(review): several original lines are missing from this excerpt
    # (e.g. the accumulation header for `wall` and the `collision = (` /
    # `hit = (` expression headers); code is left byte-identical, only
    # comments were added.
    rnd = torch.rand(nb, height, width)

    # Each pass marks the argmax cell of `rnd` (one-hot over the flattened
    # grid) as a wall cell, then zeroes it so the next pass picks elsewhere.
    for k in range(nb_walls):
            rnd.flatten(1).argmax(dim=1)[:, None]
            == torch.arange(rnd.flatten(1).size(1))[None, :]
        ).long().reshape(rnd.size())
        rnd = rnd * (1 - wall.clamp(max=1))

    # Same argmax-sampling trick to place the coins at step 0.
    rnd = torch.rand(nb, height, width)
    rnd[:, 0, 0] = 0  # Do not put coin at the agent's starting
    coins = torch.zeros(nb, T, height, width, dtype=torch.int64)
    rnd = rnd * (1 - wall.clamp(max=1))  # never place a coin on a wall
    for k in range(nb_coins):
        coins[:, 0] = coins[:, 0] + (
            rnd.flatten(1).argmax(dim=1)[:, None]
            == torch.arange(rnd.flatten(1).size(1))[None, :]
        ).long().reshape(rnd.size())

        rnd = rnd * (1 - coins[:, 0].clamp(max=1))

    # States start as the wall map replicated over the T time steps.
    states = wall[:, None, :, :].expand(-1, T, -1, -1).clone()

    agent = torch.zeros(states.size(), dtype=torch.int64)
    # presumably the agent's start cell is set on a line not visible here — TODO confirm
    agent_actions = torch.randint(5, (nb, T))
    rewards = torch.zeros(nb, T, dtype=torch.int64)

    troll = torch.zeros(states.size(), dtype=torch.int64)
    troll[:, 0, -1, -1] = 1
    troll_actions = torch.randint(5, (nb, T))

    # all_moves[:, a] holds the occupancy map after action a: slot 0 is
    # "stay", slots 1-4 are the four axis-aligned one-cell shifts.
    all_moves = agent.new(nb, 5, height, width)
    for t in range(T - 1):
        # Candidate agent maps for each of the 5 actions.
        all_moves[:, 0] = agent[:, t]
        all_moves[:, 1, 1:, :] = agent[:, t, :-1, :]
        all_moves[:, 2, :-1, :] = agent[:, t, 1:, :]
        all_moves[:, 3, :, 1:] = agent[:, t, :, :-1]
        all_moves[:, 4, :, :-1] = agent[:, t, :, 1:]
        # Select the map matching the sampled action via a one-hot mask.
        a = F.one_hot(agent_actions[:, t], num_classes=5)[:, :, None, None]
        after_move = (all_moves * a).sum(dim=1)
        # Collision test: mask out walls and the troll, then reduce.  The
        # header line of this expression is missing from the excerpt.
            (after_move * (1 - wall) * (1 - troll[:, t]))
            .sum(dim=1)[:, None, None]
        # When `collision` flags a blocked move the agent stays put.
        agent[:, t + 1] = collision * agent[:, t] + (1 - collision) * after_move

        # Same move/collision machinery for the troll, which may not step
        # onto the agent's updated position.
        all_moves[:, 0] = troll[:, t]
        all_moves[:, 1, 1:, :] = troll[:, t, :-1, :]
        all_moves[:, 2, :-1, :] = troll[:, t, 1:, :]
        all_moves[:, 3, :, 1:] = troll[:, t, :, :-1]
        all_moves[:, 4, :, :-1] = troll[:, t, :, 1:]
        a = F.one_hot(troll_actions[:, t], num_classes=5)[:, :, None, None]
        after_move = (all_moves * a).sum(dim=1)
            (after_move * (1 - wall) * (1 - agent[:, t + 1]))
            .sum(dim=1)[:, None, None]
        troll[:, t + 1] = collision * troll[:, t] + (1 - collision) * after_move

        # `hit`: agent and troll are 4-adjacent after the move (the four
        # shifted products, one per direction); header line missing here.
            (agent[:, t + 1, 1:, :] * troll[:, t + 1, :-1, :]).flatten(1).sum(dim=1)
            + (agent[:, t + 1, :-1, :] * troll[:, t + 1, 1:, :]).flatten(1).sum(dim=1)
            + (agent[:, t + 1, :, 1:] * troll[:, t + 1, :, :-1]).flatten(1).sum(dim=1)
            + (agent[:, t + 1, :, :-1] * troll[:, t + 1, :, 1:]).flatten(1).sum(dim=1)
        hit = (hit > 0).long()

        # assert hit.min() == 0 and hit.max() <= 1

        # Coin pickup: the coin disappears; reward is -1 on a hit,
        # otherwise the number of coins collected this step.
        got_coin = (agent[:, t + 1] * coins[:, t]).flatten(1).sum(dim=1)
        coins[:, t + 1] = coins[:, t] * (1 - agent[:, t + 1])

        rewards[:, t + 1] = -hit + (1 - hit) * got_coin

    # Cell encoding by contribution: wall 1 (already in states), agent +2,
    # troll +3, coin +4 (suppressed where the troll stands).
    states = states + 2 * agent + 3 * troll + 4 * coins * (1 - troll)

    return states, agent_actions, rewards
157 ######################################################################
def episodes2seq(states, actions, rewards):
    # Encode a batch of episodes into one flat token sequence per episode,
    # one block per time step: [lookahead reward | state grid | action | reward].
    # NOTE(review): the `return torch.cat([...` wrapper lines around the
    # four code-mapped pieces are missing from this excerpt.
    neg = rewards.new_zeros(rewards.size())
    pos = rewards.new_zeros(rewards.size())
    # neg/pos[:, t] summarize the remaining rewards from step t onward.
    for t in range(neg.size(1) - 1):
        neg[:, t] = rewards[:, t:].min(dim=-1).values
        pos[:, t] = rewards[:, t:].max(dim=-1).values
    # Lookahead reward: a future negative reward wins over any positive one.
    s = (neg < 0).long() * neg + (neg >= 0).long() * pos

            lookahead_reward2code(s[:, :, None]),
            state2code(states.flatten(2)),
            action2code(actions[:, :, None]),
            reward2code(rewards[:, :, None]),
def seq2episodes(seq, height, width):
    """Decode flat token sequences back into episode tensors.

    Inverse of episodes2seq: every time step occupies height*width + 3
    tokens laid out as [lookahead reward | state grid | action | reward].
    Returns (lookahead_rewards, states, actions, rewards), with states of
    shape (batch, steps, height, width).
    """
    it_len = height * width + 3
    seq = seq.reshape(seq.size(0), -1, it_len)
    lookahead_rewards = code2lookahead_reward(seq[:, :, 0])
    grid = code2state(seq[:, :, 1 : 1 + height * width])
    states = grid.reshape(grid.size(0), grid.size(1), height, width)
    actions = code2action(seq[:, :, it_len - 2])
    rewards = code2reward(seq[:, :, it_len - 1])
    return lookahead_rewards, states, actions, rewards
    # Body of token2str (its `def` line and the `elif (` / `):` lines of the
    # last branch are not visible in this excerpt): maps one token code to a
    # display character, one alphabet per code range.
    if t >= first_states_code and t < first_states_code + nb_states_codes:
        return " #@T$"[t - first_states_code]
    elif t >= first_actions_code and t < first_actions_code + nb_actions_codes:
        return "ISNEW"[t - first_actions_code]
    elif t >= first_rewards_code and t < first_rewards_code + nb_rewards_codes:
        return "-0+"[t - first_rewards_code]
        t >= first_lookahead_rewards_code
        and t < first_lookahead_rewards_code + nb_lookahead_rewards_codes
        return "n.pU"[t - first_lookahead_rewards_code]
    # Body of seq2str (its `def` line is not visible in this excerpt): one
    # printable string per row of `seq`, one character per token.
    return ["".join([token2str(x.item()) for x in row]) for row in seq]
208 ######################################################################
    # NOTE(review): this excerpt starts inside the signature of
    # episodes2str (opening `def episodes2str(` line not visible) and is
    # missing many interior lines; code left byte-identical, comments only.
    lookahead_rewards, states, actions, rewards, unicode=False, ansi_colors=False

        # Box-drawing characters for the grid (unicode branch vs the
        # ASCII fallback two lines below).
        # vert, hori, cross, thin_hori = "║", "═", "╬", "─"
        vert, hori, cross, thin_vert, thin_hori = "┃", "━", "╋", "│", "─"

        vert, hori, cross, thin_vert, thin_hori = "|", "-", "+", "|", "-"

    # Horizontal separator spanning all frames laid side by side.
    hline = (cross + hori * states.size(-1)) * states.size(1) + cross + "\n"

    for n in range(states.size(0)):

            # Symbol for one cell value; '?' for out-of-range codes.
            return "?" if v < 0 or v >= len(symbols) else symbols[v]

        # One text row per grid row, with the frames joined by separators.
        for i in range(states.size(2)):
                ["".join([state_symbol(v) for v in row]) for row in states[n, :, i]]

        # result += (vert + thin_hori * states.size(-1)) * states.size(1) + vert + "\n"

        def status_bar(a, r, lr=None):
            # One-line action/reward(/lookahead) caption under each frame.
            a, r = a.item(), r.item()
            sb_a = "ISNEW"[a] if a >= 0 and a < 5 else "?"
            sb_r = "- +"[r + 1] if r in {-1, 0, 1} else "?"

            sb_lr = "n pU"[lr + 1] if lr in {-1, 0, 1, 2} else "?"

                # Pad the caption to the frame width.
                + " " * (states.size(-1) - 1 - len(sb_a + sb_r + sb_lr))

            for a, r, lr in zip(actions[n], rewards[n], lookahead_rewards[n])

        # Colorize the troll/agent/coin glyphs with ANSI escape codes.
        for u, c in [("T", 31), ("@", 32), ("$", 34)]:
            result = result.replace(u, f"\u001b[{c}m{u}\u001b[0m")
282 ######################################################################
def save_seq_as_anim_script(seq, filename):
    # Write a shell script that replays the episodes in `seq` as a terminal
    # animation (one `cat << EOF` frame per time step, 0.25 s apart).
    # NOTE(review): relies on module-level `height`, `width` and `T` (set in
    # the __main__ block) — confirm before reusing as a library function.
    it_len = height * width + 3

        # Interior of a reshape/transpose of `seq` to time-major layout;
        # the header line of this expression is missing from the excerpt.
        seq.reshape(seq.size(0), -1, it_len)
        .reshape(T, seq.size(0), -1)

    with open(filename, "w") as f:
            f.write("cat << EOF\n")
            # for i in range(seq.size(2)):
            # lr, s, a, r = seq2episodes(seq[t : t + 1, :, i], height, width)
            # NOTE(review): 5 and 10 below look hard-coded to the caller's
            # batch size and T — confirm against the __main__ block.
            lr, s, a, r = seq2episodes(
                seq[t : t + 1, :].reshape(5, 10 * it_len), height, width
            f.write(episodes2str(lr, s, a, r, unicode=True, ansi_colors=True))

            f.write("sleep 0.25\n")
    # NOTE(review): f-string has no placeholder; likely meant f"Saved {filename}".
    print(f"Saved (unknown)")
if __name__ == "__main__":
    # Smoke test: generate a handful of episodes, round-trip them through
    # the token encoding, pretty-print the result, then dump an animation
    # script for a second batch.
    nb, height, width, T, nb_walls = 6, 5, 7, 10, 5
    states, actions, rewards = generate_episodes(nb, height, width, T, nb_walls)
    seq = episodes2seq(states, actions, rewards)
    lr, s, a, r = seq2episodes(seq, height, width)
    print(episodes2str(lr, s, a, r, unicode=True, ansi_colors=True))

    # for s in seq2str(seq):

    # NOTE(review): the closing `)` of this call is on a line missing from
    # this excerpt.
    states, actions, rewards = generate_episodes(
        nb=nb, height=height, width=width, T=T, nb_walls=3
    seq = episodes2seq(states, actions, rewards)
    save_seq_as_anim_script(seq, "anim.sh")