Update.
author    François Fleuret <francois@fleuret.org>
          Tue, 31 Oct 2023 08:14:35 +0000 (09:14 +0100)
committer François Fleuret <francois@fleuret.org>
          Tue, 31 Oct 2023 08:14:35 +0000 (09:14 +0100)
picocrafter.py [new file with mode: 0755]

diff --git a/picocrafter.py b/picocrafter.py
new file mode 100755 (executable)
index 0000000..33a00c1
--- /dev/null
@@ -0,0 +1,437 @@
+#!/usr/bin/env python
+
+#########################################################################
+# This program is free software: you can redistribute it and/or modify  #
+# it under the terms of the version 3 of the GNU General Public License #
+# as published by the Free Software Foundation.                         #
+#                                                                       #
+# This program is distributed in the hope that it will be useful, but   #
+# WITHOUT ANY WARRANTY; without even the implied warranty of            #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      #
+# General Public License for more details.                              #
+#                                                                       #
+# You should have received a copy of the GNU General Public License     #
+# along with this program. If not, see <http://www.gnu.org/licenses/>.  #
+#                                                                       #
+# Written by and Copyright (C) Francois Fleuret                         #
+# Contact <francois.fleuret@unige.ch> for comments & bug reports        #
+#########################################################################
+
+# This is a tiny rogue-like environment implemented with tensor
+# operations so that it runs efficiently in batches on a GPU. On an
+# RTX 4090 it can initialize ~20k environments per second and run
+# ~40k iterations per second.
+#
+# The agent "@" moves in a maze-like grid with random walls "#". There
+# are five actions: move NESW or do not move.
+#
+# There are monsters "$" moving randomly. The agent gets hit by every
+# monster present in one of its 4 directly neighboring cells at the
+# end of the moves, and each hit results in a reward of -1.
+#
+# The agent starts with 5 life points; each hit costs it 1pt. When
+# its life level reaches 0 it dies, gets a reward of -10, and the
+# episode is over. At every step it recovers 1/20th of a life point,
+# up to a maximum of 5pt.
+#
+# The agent can carry "keys" ("a", "b", "c") that open "vaults" ("A",
+# "B", "C"). They keys can only be used in sequence: initially the
+# agent can move only to free spaces, or to the "a", in which case it
+# now carries it, and can move to free spaces or the "A". When it
+# moves to the "A", it gets a reward and loses the "a", but can now
+# move to the "b", etc. Rewards are 1 for "A" and "B" and 10 for "C".
+
+######################################################################
+
+import torch
+
+from torch.nn.functional import conv2d
+
+######################################################################
+
+
+class PicoCrafterEngine:
+    def __init__(
+        self,
+        world_height=27,
+        world_width=27,
+        nb_walls=27,
+        margin=2,
+        view_height=5,
+        view_width=5,
+        device=torch.device("cpu"),
+    ):
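+        # The world without its margin must tile exactly with the
+        # views without their margin, so that every agent view is a
+        # full tile of the world (see views())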
+        assert (world_height - 2 * margin) % (view_height - 2 * margin) == 0
+        assert (world_width - 2 * margin) % (view_width - 2 * margin) == 0
+
+        self.device = device
+
+        self.world_height = world_height
+        self.world_width = world_width
+        self.margin = margin
+        self.view_height = view_height
+        self.view_width = view_width
+        self.nb_walls = nb_walls
+        self.life_level_max = 5
+        self.life_level_gain_100th = 5
+        self.reward_per_hit = -1
+        self.reward_death = -10
+
+        self.tokens = " +#@$aAbBcC"
+        self.token2id = dict([(t, n) for n, t in enumerate(self.tokens)])
+        self.id2token = dict([(n, t) for n, t in enumerate(self.tokens)])
+
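+        # The objects can only be taken in sequence: carrying key "a"
+        # makes vault "A" accessible, opening "A" makes key "b"
+        # accessible, and so on, up to vault "C"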
+        self.next_object = dict(
+            [
+                (self.token2id[s], self.token2id[t])
+                for (s, t) in [
+                    ("a", "A"),
+                    ("A", "b"),
+                    ("b", "B"),
+                    ("B", "c"),
+                    ("c", "C"),
+                ]
+            ]
+        )
+
+        self.object_reward = dict(
+            [
+                (self.token2id[t], r)
+                for (t, r) in [
+                    ("a", 0),
+                    ("A", 1),
+                    ("b", 0),
+                    ("B", 1),
+                    ("c", 0),
+                    ("C", 10),
+                ]
+            ]
+        )
+
+        self.accessible_object_to_inventory = dict(
+            [
+                (self.token2id[s], self.token2id[t])
+                for (s, t) in [
+                    ("a", " "),
+                    ("A", "a"),
+                    ("b", " "),
+                    ("B", "b"),
+                    ("c", " "),
+                    ("C", " "),
+                ]
+            ]
+        )
+
+    def reset(self, nb_agents):
+        self.worlds = self.create_worlds(
+            nb_agents, self.world_height, self.world_width, self.nb_walls, self.margin
+        ).to(self.device)
+        self.life_level_in_100th = torch.full(
+            (nb_agents,), self.life_level_max * 100, device=self.device
+        )
+        self.accessible_object = torch.full(
+            (nb_agents,), self.token2id["a"], device=self.device
+        )
+
+    def create_mazes(self, nb, height, width, nb_walls):
+        m = torch.zeros(nb, height, width, dtype=torch.int64, device=self.device)
+        m[:, 0, :] = 1
+        m[:, -1, :] = 1
+        m[:, :, 0] = 1
+        m[:, :, -1] = 1
+
+        i = torch.arange(height, device=m.device)[None, :, None]
+        j = torch.arange(width, device=m.device)[None, None, :]
+
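+        # Grow the walls iteratively: at every step pick, in each maze,
+        # a random free cell with even coordinates, and extend it into
+        # a wall segment of length up to 3 in one of the four cardinal
+        # directions, chosen at random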
+        for _ in range(nb_walls):
+            q = torch.rand(m.size(), device=m.device).flatten(1).sort(-1).indices * (
+                (1 - m) * (i % 2 == 0) * (j % 2 == 0)
+            ).flatten(1)
+            q = (q == q.max(dim=-1, keepdim=True).values).long().view(m.size())
+            a = q[:, None].expand(-1, 4, -1, -1).clone()
+            a[:, 0, :-1, :] += q[:, 1:, :]
+            a[:, 0, :-2, :] += q[:, 2:, :]
+            a[:, 1, 1:, :] += q[:, :-1, :]
+            a[:, 1, 2:, :] += q[:, :-2, :]
+            a[:, 2, :, :-1] += q[:, :, 1:]
+            a[:, 2, :, :-2] += q[:, :, 2:]
+            a[:, 3, :, 1:] += q[:, :, :-1]
+            a[:, 3, :, 2:] += q[:, :, :-2]
+            a = a[
+                torch.arange(a.size(0), device=a.device),
+                torch.randint(4, (a.size(0),), device=a.device),
+            ]
+            m = (m + q + a).clamp(max=1)
+
+        return m
+
+    def create_worlds(self, nb, height, width, nb_walls, margin=2):
+        margin -= 1  # The maze adds a wall all around
+        m = self.create_mazes(nb, height - 2 * margin, width - 2 * margin, nb_walls)
+        q = m.flatten(1)
+        z = "@aAbBcC$$$$$"  # What to add to the maze
+        u = torch.rand(q.size(), device=q.device) * (1 - q)
+        r = u.sort(dim=-1, descending=True).indices[:, : len(z)]
+
+        q *= self.token2id["#"]
+        q[
+            torch.arange(q.size(0), device=q.device)[:, None].expand_as(r), r
+        ] = torch.tensor([self.token2id[c] for c in z], device=q.device)[None, :]
+
+        if margin > 0:
+            r = m.new_full(
+                (m.size(0), m.size(1) + margin * 2, m.size(2) + margin * 2),
+                self.token2id["+"],
+            )
+            r[:, margin:-margin, margin:-margin] = m
+            m = r
+        return m
+
+    def nb_actions(self):
+        return 5
+
+    def nb_view_tokens(self):
+        return len(self.tokens)
+
+    def min_max_reward(self):
+        return (
+            min(4 * self.reward_per_hit, self.reward_death),
+            max(self.object_reward.values()),
+        )
+
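+    # Makes all the agents play one step simultaneously, given a
+    # tensor of actions of shape (nb_agents,). Returns (reward,
+    # inventory, life_level), all of shape (nb_agents,)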
+    def step(self, actions):
+        a = (self.worlds == self.token2id["@"]).nonzero()
+        self.worlds[a[:, 0], a[:, 1], a[:, 2]] = self.token2id[" "]
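+        # Displacements for the five actions: do not move, then move
+        # N, E, S, W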
+        s = torch.tensor([[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]], device=self.device)
+        b = a.clone()
+        b[:, 1:] = b[:, 1:] + s[actions[b[:, 0]]]
+
+        # position is empty
+        o = (self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.token2id[" "]).long()
+        # or it is the next accessible object
+        q = (
+            self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.accessible_object[b[:, 0]]
+        ).long()
+        o = (o + q).clamp(max=1)[:, None]
+        b = (1 - o) * a + o * b
+        self.worlds[b[:, 0], b[:, 1], b[:, 2]] = self.token2id["@"]
+
+        nb_hits = self.monster_moves()
+
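+        # Life levels are stored in 100ths of a point so that the
+        # per-step regeneration of 1/20th of a point remains an integer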
+        alive_before = self.life_level_in_100th > 0
+        self.life_level_in_100th[alive_before] = (
+            self.life_level_in_100th[alive_before]
+            + self.life_level_gain_100th
+            - nb_hits[alive_before] * 100
+        ).clamp(max=self.life_level_max * 100)
+        alive_after = self.life_level_in_100th > 0
+        self.worlds[torch.logical_not(alive_after)] = self.token2id["#"]
+        reward = nb_hits * self.reward_per_hit
+
+        for i in range(q.size(0)):
+            if q[i] == 1:
+                reward[i] += self.object_reward[self.accessible_object[i].item()]
+                self.accessible_object[i] = self.next_object[
+                    self.accessible_object[i].item()
+                ]
+
+        reward = (
+            reward + alive_before.long() * (1 - alive_after.long()) * self.reward_death
+        )
+        inventory = torch.tensor(
+            [
+                self.accessible_object_to_inventory[s.item()]
+                for s in self.accessible_object
+            ],
+            device=self.device,
+        )
+
+        reward[torch.logical_not(alive_before)] = 0
+        return reward, inventory, self.life_level_in_100th // 100
+
+    def monster_moves(self):
+        # Current positions of the monsters
+        m = (self.worlds == self.token2id["$"]).long().flatten(1)
+
+        # Maximum number of monsters in any world of the batch
+        n = m.sum(-1).max()
+
+        # Create a tensor with one channel per monster
+        r = (
+            (torch.rand(m.size(), device=m.device) * m)
+            .sort(dim=-1, descending=True)
+            .indices[:, :n]
+        )
+        o = m.new_zeros((m.size(0), n) + m.size()[1:])
+        i = torch.arange(o.size(0), device=o.device)[:, None].expand(-1, o.size(1))
+        j = torch.arange(o.size(1), device=o.device)[None, :].expand(o.size(0), -1)
+        o[i, j, r] = 1
+        o = o * m[:, None]
+
+        # Create the tensor of possible motions
+        o = o.view((self.worlds.size(0), n) + self.worlds.flatten(1).size()[1:])
+        move_kernel = torch.tensor(
+            [[[[0.0, 1.0, 0.0], [1.0, 1.0, 1.0], [0.0, 1.0, 0.0]]]], device=o.device
+        )
+
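+        # p[:, k] is 1 on the cross-shaped neighborhood of monster k
+        # (its own cell and its 4 direct neighbors); the admissible
+        # move targets are further restricted to empty cells in the
+        # loop below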
+        p = (
+            conv2d(
+                o.view(
+                    o.size(0) * o.size(1), 1, self.worlds.size(-2), self.worlds.size(-1)
+                ).float(),
+                move_kernel,
+                padding=1,
+            ).view(o.size())
+            == 1.0
+        ).long()
+
+        # Let's do the moves per se
+        i = torch.arange(self.worlds.size(0), device=self.worlds.device)[
+            :, None
+        ].expand_as(r)
+
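+        # Move the monsters one at a time: erase monster k from its
+        # current cell, and re-draw it on a cell picked at random among
+        # the admissible targets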
+        for k in range(p.size(1)):
+            u = o[:, k].sort(dim=-1, descending=True).indices[:, :1]
+            q = p[:, k] * (self.worlds.flatten(1) == self.token2id[" "]) + o[:, k]
+            r = (
+                (q * torch.rand(q.size(), device=q.device))
+                .sort(dim=-1, descending=True)
+                .indices[:, :1]
+            )
+            self.worlds.flatten(1)[i, u] = self.token2id[" "]
+            self.worlds.flatten(1)[i, r] = self.token2id["$"]
+
+        nb_hits = (
+            (
+                conv2d(
+                    (self.worlds == self.token2id["$"]).float()[:, None],
+                    move_kernel,
+                    padding=1,
+                )
+                .long()
+                .squeeze(1)
+                * (self.worlds == self.token2id["@"]).long()
+            )
+            .flatten(1)
+            .sum(-1)
+        )
+
+        return nb_hits
+
+    def views(self):
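+        # The world is partitioned into tiles of size (view - 2 *
+        # margin); each agent's view is the tile it stands in,
+        # extended by margin cells on every side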
+        i_height, i_width = (
+            self.view_height - 2 * self.margin,
+            self.view_width - 2 * self.margin,
+        )
+        a = (self.worlds == self.token2id["@"]).nonzero()
+        y = i_height * ((a[:, 1] - self.margin) // i_height)
+        x = i_width * ((a[:, 2] - self.margin) // i_width)
+        n = a[:, 0][:, None, None].expand(-1, self.view_height, self.view_width)
+        i = (
+            torch.arange(self.view_height, device=a.device)[None, :, None]
+            + y[:, None, None]
+        ).expand_as(n)
+        j = (
+            torch.arange(self.view_width, device=a.device)[None, None, :]
+            + x[:, None, None]
+        ).expand_as(n)
+        v = self.worlds.new_full(
+            (self.worlds.size(0), self.view_height, self.view_width), self.token2id["#"]
+        )
+
+        v[a[:, 0]] = self.worlds[n, i, j]
+
+        return v
+
+    def print_worlds(
+        self, src=None, comments=[], width=None, printer=print, ansi_term=False
+    ):
+        if src is None:
+            src = self.worlds
+
+        if width is None:
+            width = src.size(2)
+
+        def token(n):
+            n = n.item()
+            if n in self.id2token:
+                return self.id2token[n]
+            else:
+                return "?"
+
+        for k in range(src.size(1)):
+            s = ["".join([token(n) for n in m[k]]) for m in src]
+            s = [r + " " * (width - len(r)) for r in s]
+            if ansi_term:
+
+                def colorize(x):
+                    for u, c in [("#", 40), ("$", 31), ("@", 32)] + [
+                        (x, 36) for x in "aAbBcC"
+                    ]:
+                        x = x.replace(u, f"\u001b[{c}m{u}\u001b[0m")
+                    return x
+
+                s = [colorize(x) for x in s]
+            printer(" | ".join(s))
+
+        s = [c + " " * (width - len(c)) for c in comments]
+        printer(" | ".join(s))
+
+
+######################################################################
+
+if __name__ == "__main__":
+    import os, time
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # Benchmarking configuration
+    # nb_agents, nb_iter, display, ansi_term = 1000, 100, False, False
+    # Interactive demo configuration
+    nb_agents, nb_iter, display, ansi_term = 3, 10000, True, True
+
+    start_time = time.perf_counter()
+    engine = PicoCrafterEngine(
+        world_height=27,
+        world_width=27,
+        nb_walls=35,
+        view_height=9,
+        view_width=9,
+        margin=4,
+        device=device,
+    )
+
+    engine.reset(nb_agents)
+
+    print(f"timing {nb_agents/(time.perf_counter() - start_time)} init per s")
+
+    start_time = time.perf_counter()
+
+    for k in range(nb_iter):
+        actions = torch.randint(engine.nb_actions(), (nb_agents,), device=device)
+        rewards, inventories, life_levels = engine.step(actions)
+
+        if display:
+            os.system("clear")
+            engine.print_worlds(
+                ansi_term=ansi_term,
+            )
+            print()
+            engine.print_worlds(
+                src=engine.views(),
+                comments=[
+                    f"L{p}I{engine.id2token[s.item()]}R{r}"
+                    for p, s, r in zip(life_levels, inventories, rewards)
+                ],
+                width=engine.world_width,
+                ansi_term=ansi_term,
+            )
+            time.sleep(0.5)
+
+        if (life_levels > 0).long().sum() == 0:
+            break
+
+    print(
+        f"timing {(nb_agents*nb_iter)/(time.perf_counter() - start_time)} iteration per s"
+    )