3 #########################################################################
4 # This program is free software: you can redistribute it and/or modify #
5 # it under the terms of the version 3 of the GNU General Public License #
6 # as published by the Free Software Foundation. #
8 # This program is distributed in the hope that it will be useful, but #
9 # WITHOUT ANY WARRANTY; without even the implied warranty of #
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU #
11 # General Public License for more details. #
13 # You should have received a copy of the GNU General Public License #
14 # along with this program. If not, see <http://www.gnu.org/licenses/>. #
16 # Written by and Copyright (C) Francois Fleuret #
17 # Contact <francois.fleuret@unige.ch> for comments & bug reports #
18 #########################################################################
20 # This is a tiny rogue-like environment implemented with tensor
21 # operations so that it runs efficiently in batches on a GPU. On an RTX4090
22 # it can initialize ~20k environments per second and run ~40k iterations per second.
25 # The agent "@" moves in a maze-like grid with random walls "#". There
26 # are five actions: move NESW or do not move.
28 # There are monsters "$" moving randomly. The agent gets hit by every
29 # monster present in one of its 4 neighboring cells (N, E, S, W) at the end of
30 # the moves; each hit results in a reward of -1.
32 # The agent starts with 5 life points and each hit costs it 1 point. When
33 # its life gets to 0 it dies, gets an additional reward of -10 and the episode
34 # is over. At every step it recovers 1/20th of a life point, up to a maximum of 5 points.
37 # The agent can carry "keys" ("a", "b", "c") that open "vaults" ("A",
38 # "B", "C"). The keys and vaults can only be used in sequence:
39 # initially the agent can move only to free spaces, or to the "a", in
40 # which case the key is removed from the environment and the agent now
41 # carries it, and can move to free spaces or the "A". When it moves to
42 # the "A", it gets a reward, loses the "a", the "A" is removed from
43 # the environment, and it can now move to the "b", etc. Rewards are 1 for
44 # "A" and "B" and 10 for "C".
46 ######################################################################
50 from torch.nn.functional import conv2d
52 ######################################################################
# Batched environment engine: every piece of per-agent state is a tensor
# whose first dimension indexes the agent/world (see reset() and step()).
#
# NOTE(review): this copy of the file is damaged. Every line carries a stray
# original line number, indentation has been stripped, and many lines are
# missing. Here the embedded numbers jump from 55 to 64, so the `__init__(`
# header and most of its parameters are absent. Judging by the body, those
# parameters are world_height, world_width, nb_walls, margin, view_height and
# view_width. The assignments of self.device / self.margin, which later
# methods read, are presumably on the missing lines 68-70/73. Restore from a
# clean copy before running.
55 class PicroCrafterEngine:
64 device=torch.device("cpu"),
# The playable interior (world minus `margin` cells on each side) must tile
# exactly into windows of the view interior (view minus `margin` on each
# side). The view code snaps the agent's position to that grid with `//`.
66 assert (world_height - 2 * margin) % (view_height - 2 * margin) == 0
67 assert (world_width - 2 * margin) % (view_width - 2 * margin) == 0
71 self.world_height = world_height
72 self.world_width = world_width
74 self.view_height = view_height
75 self.view_width = view_width
76 self.nb_walls = nb_walls
# Life is tracked in hundredths of a point. The agent regains
# life_level_gain_100th / 100 = 1/20th of a point per step, capped at
# life_level_max points.
77 self.life_level_max = 5
78 self.life_level_gain_100th = 5
# Reward per monster hit, plus the extra reward added on the step the agent dies.
79 self.reward_per_hit = -1
80 self.reward_death = -10
# Token alphabet: a token's id is its index in this string.
82 self.tokens = " +#@$aAbBcC"
83 self.token2id = dict([(t, n) for n, t in enumerate(self.tokens)])
84 self.id2token = dict([(n, t) for n, t in enumerate(self.tokens)])
# Maps the id of the object the agent may currently step on to the id of the
# next one in the sequence; step() advances accessible_object with it. The
# (source, target) pairs are not visible in this copy. Presumably the order
# is "a" -> "A" -> "b" -> ... per the header -- TODO confirm.
86 self.next_object = dict(
88 (self.token2id[s], self.token2id[t])
# Reward granted when the agent reaches each object (added in step()). The
# values are not visible here; the header states 1 for "A"/"B" and 10 for "C".
100 self.object_reward = dict(
102 (self.token2id[t], r)
# Maps the currently accessible object id to the token id that step()
# reports as the agent's inventory.
# NOTE(review): "acessible" is misspelled. Renaming must be done together
# with its use in step().
114 self.acessible_object_to_inventory = dict(
116 (self.token2id[s], self.token2id[t])
128 def reset(self, nb_agents):
129 self.worlds = self.create_worlds(
130 nb_agents, self.world_height, self.world_width, self.nb_walls, self.margin
132 self.life_level_in_100th = torch.full(
133 (nb_agents,), self.life_level_max * 100, device=self.device
135 self.accessible_object = torch.full(
136 (nb_agents,), self.token2id["a"], device=self.device
# Generate `nb` random mazes as an int64 (nb, height, width) tensor in which
# 1 marks a wall and 0 a free cell.
#
# Each of the `nb_walls` iterations does the following per maze: pick one
# random free cell whose row and column are both even, then add a wall
# segment of up to 3 cells (the cell plus the next two, truncated at the
# border) in one of the four directions, chosen at random.
#
# NOTE(review): original lines 141-145, 152, 163, 166 and 168-170 are missing
# from this copy. Presumably:
#   - 141-145 draw the outer wall (create_worlds() relies on "The maze adds a
#     wall all around");
#   - 152 flattens the mask so it matches the flattened scores;
#   - 163/166 open and close the `a = a[...]` direction selection;
#   - 168-170 return `m`.
# TODO confirm.
139 def create_mazes(self, nb, height, width, nb_walls):
140 m = torch.zeros(nb, height, width, dtype=torch.int64, device=self.device)
# Row / column index grids, broadcastable against m.
146 i = torch.arange(height, device=m.device)[None, :, None]
147 j = torch.arange(width, device=m.device)[None, None, :]
149 for _ in range(nb_walls):
# Score each cell with a random permutation of the cell indices, zeroed
# outside the allowed cells (free, even row, even column). The maximum is
# then a uniformly random allowed cell.
# NOTE(review): if a maze has no allowed cell left, or a single one that
# draws rank 0, every score is 0 and the `== max` test below selects
# every cell. This assumes line 152 only flattens the mask -- confirm
# nb_walls is small enough for this never to happen.
150 q = torch.rand(m.size(), device=m.device).flatten(1).sort(-1).indices * (
151 (1 - m) * (i % 2 == 0) * (j % 2 == 0)
# One-hot map of the chosen cell.
153 q = (q == q.max(dim=-1, keepdim=True).values).long().view(m.size())
# Four candidate segments starting at the chosen cell:
# channel 0 = the cell and the two above it, 1 = below, 2 = left, 3 = right.
154 a = q[:, None].expand(-1, 4, -1, -1).clone()
155 a[:, 0, :-1, :] += q[:, 1:, :]
156 a[:, 0, :-2, :] += q[:, 2:, :]
157 a[:, 1, 1:, :] += q[:, :-1, :]
158 a[:, 1, 2:, :] += q[:, :-2, :]
159 a[:, 2, :, :-1] += q[:, :, 1:]
160 a[:, 2, :, :-2] += q[:, :, 2:]
161 a[:, 3, :, 1:] += q[:, :, :-1]
162 a[:, 3, :, 2:] += q[:, :, :-2]
# Per maze, keep one of the four directions, drawn uniformly.
164 torch.arange(a.size(0), device=a.device),
165 torch.randint(4, (a.size(0),), device=a.device),
# Add the segment to the maze, keeping values binary.
167 m = (m + q + a).clamp(max=1)
# Build `nb` worlds of size (height, width). Each world is a random maze,
# padded on every side. The agent, the keys/vaults and five monsters (the
# characters of `z`) are placed on distinct random free cells.
#
# NOTE(review): original lines 174, 178, 180, 183-185, 187-188 and 190-191
# are missing from this copy. Presumably:
#   - 174 sets `q` to the flattened 0/1 maze;
#   - 184-188 reshape it and allocate the padded result `r` (its fill value
#     is not visible);
#   - 191 returns `r`.
# TODO confirm.
171 def create_worlds(self, nb, height, width, nb_walls, margin=2):
# NOTE(review): with margin=1 this becomes 0, so the `margin:-margin` slices
# below are empty and the final assignment fails. margin must be >= 2.
172 margin -= 1 # The maze adds a wall all around
173 m = self.create_mazes(nb, height - 2 * margin, width - 2 * margin, nb_walls)
175 z = "@aAbBcC$$$$$" # What to add to the maze
# Random positive scores on free cells and 0 on walls (assuming q holds the
# 0/1 maze). The len(z) best-scored cells receive the items, so a maze with
# fewer than len(z) free cells would get items placed on walls.
176 u = torch.rand(q.size(), device=q.device) * (1 - q)
177 r = u.sort(dim=-1, descending=True).indices[:, : len(z)]
# Cells with value 1 become wall tokens; free cells stay 0, i.e. the " " token.
179 q *= self.token2id["#"]
181 torch.arange(q.size(0), device=q.device)[:, None].expand_as(r), r
182 ] = torch.tensor([self.token2id[c] for c in z], device=q.device)[None, :]
186 (m.size(0), m.size(1) + margin * 2, m.size(2) + margin * 2),
189 r[:, margin:-margin, margin:-margin] = m
# Number of discrete actions accepted by step(): stay, N, E, S, W. These are
# the rows of the motion table `s` in step().
# NOTE(review): the body (original line 194) is missing from this copy.
# Presumably it returns 5 -- TODO confirm.
193 def nb_actions(self):
196 def nb_view_tokens(self):
197 return len(self.tokens)
199 def min_max_reward(self):
201 min(4 * self.reward_per_hit, self.reward_death),
202 max(self.object_reward.values()),
# Advance every environment by one step.
#
# `actions` holds one action id per world (it is indexed by world id, see
# actions[b[:, 0]]). Each id selects a row of the motion table `s`: stay, N,
# E, S, W. Returns, per world: the step reward, the inventory token id
# (looked up from accessible_object), and the remaining life in whole points.
#
# NOTE(review): original lines 209, 211-212, 215, 217, 221, 223, 233, 235,
# 239-241, 243, 245 and 248-250 are missing from this copy. These include the
# initialization of `b` and of both `q` tensors, so the comments below only
# describe what the visible lines do.
205 def step(self, actions):
# Lift every agent off the map. `a` holds (world, row, col) for each "@".
206 a = (self.worlds == self.token2id["@"]).nonzero()
207 self.worlds[a[:, 0], a[:, 1], a[:, 2]] = self.token2id[" "]
208 s = torch.tensor([[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]], device=self.device)
# Candidate position = current position + the chosen motion.
210 b[:, 1:] = b[:, 1:] + s[actions[b[:, 0]]]
# The move is allowed if the target cell is free...
213 o = (self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.token2id[" "]).long()
214 # or it is the next accessible object
216 self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.accessible_object[b[:, 0]]
# o == 1: move to b. o == 0: stay at a. Then put the agents back on the map.
218 o = (o + q).clamp(max=1)[:, None]
219 b = (1 - o) * a + o * b
220 self.worlds[b[:, 0], b[:, 1], b[:, 2]] = self.token2id["@"]
222 nb_hits = self.monster_moves()
# For agents still alive, update life (in 1/100th points): regenerate,
# subtract one full point per hit, and cap at the maximum. Worlds whose
# agent is now dead are filled with "#".
224 alive_before = self.life_level_in_100th > 0
225 self.life_level_in_100th[alive_before] = (
226 self.life_level_in_100th[alive_before]
227 + self.life_level_gain_100th
228 - nb_hits[alive_before] * 100
229 ).clamp(max=self.life_level_max * 100)
230 alive_after = self.life_level_in_100th > 0
231 self.worlds[torch.logical_not(alive_after)] = self.token2id["#"]
232 reward = nb_hits * self.reward_per_hit
# Grant the object reward and advance to the next object in the sequence.
# This is presumably guarded by q[i] on the missing line 235.
# NOTE(review): this Python loop calls .item() per entry. On a GPU each
# call forces a device-to-host sync; a tensor lookup table indexed by
# accessible_object would avoid that.
234 for i in range(q.size(0)):
236 reward[i] += self.object_reward[self.accessible_object[i].item()]
237 self.accessible_object[i] = self.next_object[
238 self.accessible_object[i].item()
# Death penalty for agents that were alive before this step and are not
# anymore. It is added on top of the hit penalties.
242 reward + alive_before.long() * (1 - alive_after.long()) * self.reward_death
244 inventory = torch.tensor(
246 self.acessible_object_to_inventory[s.item()]
247 for s in self.accessible_object
# Once the accessible object is " " (presumably after the last vault),
# life is zeroed so the episode ends.
251 self.life_level_in_100th = (
252 self.life_level_in_100th
253 * (self.accessible_object != self.token2id[" "]).long()
# Agents that were already dead before this step get no reward.
256 reward[torch.logical_not(alive_before)] = 0
257 return reward, inventory, self.life_level_in_100th // 100
# Move every monster by at most one cell, or leave it in place. Monsters move
# one at a time, so two monsters never end up on the same cell. The function
# then computes the hits on the agents, which step() uses as nb_hits.
#
# NOTE(review): many original lines are missing from this copy: 262,
# 264-265, 267, 270-271, 275-277, 282-286, 288-294, 297-299, 303, 306-307,
# 310-313, 315-319 and the end of the function. These include the
# computation of `n`, `p`, `r` and the return value.
259 def monster_moves(self):
260 # Current positions of the monsters
261 m = (self.worlds == self.token2id["$"]).long().flatten(1)
263 # Total number of monsters
266 # Create a tensor with one channel per monster
268 (torch.rand(m.size(), device=m.device) * m)
269 .sort(dim=-1, descending=True)
272 o = m.new_zeros((m.size(0), n) + m.size()[1:])
273 i = torch.arange(o.size(0), device=o.device)[:, None].expand(-1, o.size(1))
274 j = torch.arange(o.size(1), device=o.device)[None, :].expand(o.size(0), -1)
278 # Create the tensor of possible motions
279 o = o.view((self.worlds.size(0), n) + self.worlds.flatten(1).size()[1:])
# Cross-shaped kernel: stay in place or move to one of the 4 neighbors.
280 move_kernel = torch.tensor(
281 [[[[0.0, 1.0, 0.0], [1.0, 1.0, 1.0], [0.0, 1.0, 0.0]]]], device=o.device
287 o.size(0) * o.size(1), 1, self.worlds.size(-2), self.worlds.size(-1)
295 # Now perform the moves proper, one monster at a time
296 i = torch.arange(self.worlds.size(0), device=self.worlds.device)[
300 for n in range(p.size(1)):
# u = flat index of this monster's current cell.
# NOTE(review): in a world with fewer monsters than channels (e.g. a world
# wiped to "#" by step() after its agent died), o[:, n] is all zeros and
# u is an arbitrary cell. Check that the missing lines guard against this.
301 u = o[:, n].sort(dim=-1, descending=True).indices[:, :1]
# Candidate cells: reachable cells that are currently free, plus the
# current cell (staying put). self.worlds is re-read at each iteration,
# so cells taken by monsters moved earlier in this loop are excluded.
302 q = p[:, n] * (self.worlds.flatten(1) == self.token2id[" "]) + o[:, n]
# Random choice among the candidates.
304 (q * torch.rand(q.size(), device=q.device))
305 .sort(dim=-1, descending=True)
# NOTE(review): these writes go through flatten(1), which returns a view
# only when self.worlds is contiguous. Otherwise the writes would be
# silently lost.
308 self.worlds.flatten(1)[i, u] = self.token2id[" "]
309 self.worlds.flatten(1)[i, r] = self.token2id["$"]
314 (self.worlds == self.token2id["$"]).float()[:, None],
320 * (self.worlds == self.token2id["@"]).long()
# NOTE(review): this method's header is missing from this copy, including its
# `def` line (somewhere in original lines 321-328). Lines 332, 337, 339-341,
# 343-344, 347-348 and the end of the method are missing as well.
#
# Extract a (view_height, view_width) window for every world. Windows are
# aligned on a fixed grid of interior tiles of size (view - 2 * margin), so
# the view only jumps when the agent crosses a tile boundary. Presumably the
# window starts at world row y / column x; the index lines are missing. That
# placement adds `margin` cells of context around the tile. Rows of `v` for
# worlds without an agent stay all "#".
329 i_height, i_width = (
330 self.view_height - 2 * self.margin,
331 self.view_width - 2 * self.margin,
333 a = (self.worlds == self.token2id["@"]).nonzero()
# Start (in world coordinates) of the tile containing the agent.
334 y = i_height * ((a[:, 1] - self.margin) // i_height)
335 x = i_width * ((a[:, 2] - self.margin) // i_width)
336 n = a[:, 0][:, None, None].expand(-1, self.view_height, self.view_width)
338 torch.arange(self.view_height, device=a.device)[None, :, None]
342 torch.arange(self.view_width, device=a.device)[None, None, :]
345 v = self.worlds.new_full(
346 (self.worlds.size(0), self.view_height, self.view_width), self.token2id["#"]
349 v[a[:, 0]] = self.worlds[n, i, j]
# NOTE(review): the `def` line of this method (original line 353) is missing
# from this copy, as are lines 355-363, 366-368, 372-374, 377, 379-380 and 383.
#
# Print worlds side by side: one text row per line, worlds separated by " | ".
# Rows are presumably colorized with ANSI escape codes when ansi_term is set.
# The rows are followed by one line of per-world `comments`, each padded to
# `width`.
# NOTE(review): `comments=[]` is a mutable default. It is only read below, so
# it is harmless today, but `comments=()` would be safer.
354 self, src=None, comments=[], width=None, printer=print, ansi_term=False
# Token id -> character.
# NOTE(review): iterating a tensor yields 0-d tensors, and tensors hash by
# identity, so this lookup only matches if `n` was converted to a Python int
# (e.g. with .item()) on a missing line -- confirm.
364 if n in self.id2token:
365 return self.id2token[n]
# Render row k of every world and pad each rendering to `width` characters.
369 for k in range(src.size(1)):
370 s = ["".join([token(n) for n in m[k]]) for m in src]
371 s = [r + " " * (width - len(r)) for r in s]
# ANSI SGR codes: 40 = black background (walls), 31 = red (monsters),
# 32 = green (agent), 36 = cyan (keys and vaults).
375 for u, c in [("#", 40), ("$", 31), ("@", 32)] + [
376 (x, 36) for x in "aAbBcC"
378 x = x.replace(u, f"\u001b[{c}m{u}\u001b[0m")
381 s = [colorize(x) for x in s]
382 printer(" | ".join(s))
384 s = [c + " " * (width - len(c)) for c in comments]
385 printer(" | ".join(s))
388 ######################################################################
# Demo / benchmark: run random actions in a batch of environments and report
# the initialization and stepping throughput, optionally printing the worlds.
# NOTE(review): many lines are missing from this copy: 391-392, 394-395,
# 398-399, 402-410 (the constructor arguments), 412, 414, 416, 421-431, 434,
# 436-439, 441-443 and the end. `time` is not imported in any visible line;
# presumably it is imported on a missing one.
390 if __name__ == "__main__":
393 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
396 # nb_agents, nb_iter, display = 1000, 100, False
397 nb_agents, nb_iter, display = 3, 10000, True
# NOTE(review): CUDA kernels run asynchronously. Without
# torch.cuda.synchronize() before reading the clock, the measured rates may
# be inaccurate on a GPU.
400 start_time = time.perf_counter()
401 engine = PicroCrafterEngine(
411 engine.reset(nb_agents)
413 print(f"timing {nb_agents/(time.perf_counter() - start_time)} init per s")
415 start_time = time.perf_counter()
417 for k in range(nb_iter):
# NOTE(review): `action` is never used. step() below receives a second,
# independently drawn random tensor; pass `action` instead.
418 action = torch.randint(engine.nb_actions(), (nb_agents,), device=device)
419 rewards, inventories, life_levels = engine.step(
420 torch.randint(engine.nb_actions(), (nb_agents,), device=device)
# One status string per world: life (L), inventory token (I), reward (R).
432 f"L{p}I{engine.id2token[s.item()]}R{r}"
433 for p, s, r in zip(life_levels, inventories, rewards)
435 width=engine.world_width,
# All agents are dead or finished; presumably the missing next line breaks
# out of the loop.
# NOTE(review): if the loop exits early, the rate printed below still
# assumes nb_iter full iterations.
440 if (life_levels > 0).long().sum() == 0:
444 f"timing {(nb_agents*nb_iter)/(time.perf_counter() - start_time)} iteration per s"