Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 2 Nov 2023 16:26:49 +0000 (17:26 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 2 Nov 2023 16:26:49 +0000 (17:26 +0100)
picocrafter.py

index 7810b67..ef861bc 100755 (executable)
@@ -52,6 +52,29 @@ from torch.nn.functional import conv2d
 ######################################################################
 
 
+def add_ansi_coloring(s):
+    if type(s) is list:
+        return [add_ansi_coloring(x) for x in s]
+
+    for u, c in [("#", 40), ("$", 31), ("@", 32)] + [(x, 36) for x in "aAbBcC"]:
+        s = s.replace(u, f"\u001b[{c}m{u}\u001b[0m")
+
+    return s
+
+
+def fusion_multi_lines(l, width_min=0):
+    l = [x if type(x) is list else [str(x)] for x in l]
+
+    def f(o, h):
+        w = max(width_min, max([len(r) for r in o]))
+        return [" " * w] * (h - len(o)) + [r + " " * (w - len(r)) for r in o]
+
+    h = max([len(x) for x in l])
+    l = [f(o, h) for o in l]
+
+    return "\n".join(["|".join([o[k] for o in l]) for k in range(h)])
+
+
 class PicroCrafterEngine:
     def __init__(
         self,
@@ -79,27 +102,29 @@ class PicroCrafterEngine:
         self.reward_per_hit = -1
         self.reward_death = -10
 
-        self.tokens = " +#@$aAbBcC."
-        self.token2id = dict([(t, n) for n, t in enumerate(self.tokens)])
-        self.id2token = dict([(n, t) for n, t in enumerate(self.tokens)])
+        self.tiles = " +#@$aAbBcC-" + "".join(
+            [str(n) for n in range(self.life_level_max + 1)]
+        )
+        self.tile2id = dict([(t, n) for n, t in enumerate(self.tiles)])
+        self.id2tile = dict([(n, t) for n, t in enumerate(self.tiles)])
 
         self.next_object = dict(
             [
-                (self.token2id[s], self.token2id[t])
+                (self.tile2id[s], self.tile2id[t])
                 for (s, t) in [
                     ("a", "A"),
                     ("A", "b"),
                     ("b", "B"),
                     ("B", "c"),
                     ("c", "C"),
-                    ("C", "."),
+                    ("C", "-"),
                 ]
             ]
         )
 
         self.object_reward = dict(
             [
-                (self.token2id[t], r)
+                (self.tile2id[t], r)
                 for (t, r) in [
                     ("a", 0),
                     ("A", 1),
@@ -113,7 +138,7 @@ class PicroCrafterEngine:
 
         self.accessible_object_to_inventory = dict(
             [
-                (self.token2id[s], self.token2id[t])
+                (self.tile2id[s], self.tile2id[t])
                 for (s, t) in [
                     ("a", " "),
                     ("A", "a"),
@@ -121,7 +146,7 @@ class PicroCrafterEngine:
                     ("B", "b"),
                     ("c", " "),
                     ("C", "c"),
-                    (".", " "),
+                    ("-", " "),
                 ]
             ]
         )
@@ -131,10 +156,10 @@ class PicroCrafterEngine:
             nb_agents, self.world_height, self.world_width, self.nb_walls, self.margin
         ).to(self.device)
         self.life_level_in_100th = torch.full(
-            (nb_agents,), self.life_level_max * 100, device=self.device
+            (nb_agents,), self.life_level_max * 100 + 99, device=self.device
         )
         self.accessible_object = torch.full(
-            (nb_agents,), self.token2id["a"], device=self.device
+            (nb_agents,), self.tile2id["a"], device=self.device
         )
 
     def create_mazes(self, nb, height, width, nb_walls):
@@ -177,15 +202,15 @@ class PicroCrafterEngine:
         u = torch.rand(q.size(), device=q.device) * (1 - q)
         r = u.sort(dim=-1, descending=True).indices[:, : len(z)]
 
-        q *= self.token2id["#"]
+        q *= self.tile2id["#"]
         q[
             torch.arange(q.size(0), device=q.device)[:, None].expand_as(r), r
-        ] = torch.tensor([self.token2id[c] for c in z], device=q.device)[None, :]
+        ] = torch.tensor([self.tile2id[c] for c in z], device=q.device)[None, :]
 
         if margin > 0:
             r = m.new_full(
                 (m.size(0), m.size(1) + margin * 2, m.size(2) + margin * 2),
-                self.token2id["+"],
+                self.tile2id["+"],
             )
             r[:, margin:-margin, margin:-margin] = m
             m = r
@@ -194,8 +219,14 @@ class PicroCrafterEngine:
     def nb_actions(self):
         return 5
 
-    def nb_view_tokens(self):
-        return len(self.tokens)
+    def action2str(self, n):
+        if n >= 0 and n < 5:
+            return "XNESW"[n]
+        else:
+            return "?"
+
+    def nb_view_tiles(self):
+        return len(self.tiles)
 
     def min_max_reward(self):
         return (
@@ -204,20 +235,20 @@ class PicroCrafterEngine:
         )
 
     def step(self, actions):
-        a = (self.worlds == self.token2id["@"]).nonzero()
-        self.worlds[a[:, 0], a[:, 1], a[:, 2]] = self.token2id[" "]
+        a = (self.worlds == self.tile2id["@"]).nonzero()
+        self.worlds[a[:, 0], a[:, 1], a[:, 2]] = self.tile2id[" "]
         s = torch.tensor([[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]], device=self.device)
         b = a.clone()
         b[:, 1:] = b[:, 1:] + s[actions[b[:, 0]]]
         # position is empty
-        o = (self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.token2id[" "]).long()
+        o = (self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.tile2id[" "]).long()
         # or it is the next accessible object
         q = (
             self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.accessible_object[b[:, 0]]
         ).long()
         o = (o + q).clamp(max=1)[:, None]
         b = (1 - o) * a + o * b
-        self.worlds[b[:, 0], b[:, 1], b[:, 2]] = self.token2id["@"]
+        self.worlds[b[:, 0], b[:, 1], b[:, 2]] = self.tile2id["@"]
 
         qq = q
         q = qq.new_zeros((self.worlds.size(0),) + qq.size()[1:])
@@ -225,14 +256,14 @@ class PicroCrafterEngine:
 
         nb_hits = self.monster_moves()
 
-        alive_before = self.life_level_in_100th > 0
+        alive_before = self.life_level_in_100th > 99
         self.life_level_in_100th[alive_before] = (
             self.life_level_in_100th[alive_before]
             + self.life_level_gain_100th
             - nb_hits[alive_before] * 100
-        ).clamp(max=self.life_level_max * 100)
-        alive_after = self.life_level_in_100th > 0
-        self.worlds[torch.logical_not(alive_after)] = self.token2id["#"]
+        ).clamp(max=self.life_level_max * 100 + 99)
+        alive_after = self.life_level_in_100th > 99
+        self.worlds[torch.logical_not(alive_after)] = self.tile2id["#"]
         reward = nb_hits * self.reward_per_hit
 
         for i in range(q.size(0)):
@@ -243,7 +274,8 @@ class PicroCrafterEngine:
                 ]
 
         reward = (
-            reward + alive_before.long() * (1 - alive_after.long()) * self.reward_death
+            alive_after.long() * reward
+            + alive_before.long() * (1 - alive_after.long()) * self.reward_death
         )
         inventory = torch.tensor(
             [
@@ -254,7 +286,7 @@ class PicroCrafterEngine:
 
         self.life_level_in_100th = (
             self.life_level_in_100th
-            * (self.accessible_object != self.token2id["."]).long()
+            * (self.accessible_object != self.tile2id["-"]).long()
         )
 
         reward[torch.logical_not(alive_before)] = 0
@@ -262,7 +294,7 @@ class PicroCrafterEngine:
 
     def monster_moves(self):
         # Current positions of the monsters
-        m = (self.worlds == self.token2id["$"]).long().flatten(1)
+        m = (self.worlds == self.tile2id["$"]).long().flatten(1)
 
         # Total number of monsters
         n = m.sum(-1).max()
@@ -303,25 +335,25 @@ class PicroCrafterEngine:
 
         for n in range(p.size(1)):
             u = o[:, n].sort(dim=-1, descending=True).indices[:, :1]
-            q = p[:, n] * (self.worlds.flatten(1) == self.token2id[" "]) + o[:, n]
+            q = p[:, n] * (self.worlds.flatten(1) == self.tile2id[" "]) + o[:, n]
             r = (
                 (q * torch.rand(q.size(), device=q.device))
                 .sort(dim=-1, descending=True)
                 .indices[:, :1]
             )
-            self.worlds.flatten(1)[i, u] = self.token2id[" "]
-            self.worlds.flatten(1)[i, r] = self.token2id["$"]
+            self.worlds.flatten(1)[i, u] = self.tile2id[" "]
+            self.worlds.flatten(1)[i, r] = self.tile2id["$"]
 
         nb_hits = (
             (
                 conv2d(
-                    (self.worlds == self.token2id["$"]).float()[:, None],
+                    (self.worlds == self.tile2id["$"]).float()[:, None],
                     move_kernel,
                     padding=1,
                 )
                 .long()
                 .squeeze(1)
-                * (self.worlds == self.token2id["@"]).long()
+                * (self.worlds == self.tile2id["@"]).long()
             )
             .flatten(1)
             .sum(-1)
@@ -334,7 +366,7 @@ class PicroCrafterEngine:
             self.view_height - 2 * self.margin,
             self.view_width - 2 * self.margin,
         )
-        a = (self.worlds == self.token2id["@"]).nonzero()
+        a = (self.worlds == self.tile2id["@"]).nonzero()
         y = i_height * ((a[:, 1] - self.margin) // i_height)
         x = i_width * ((a[:, 2] - self.margin) // i_width)
         n = a[:, 0][:, None, None].expand(-1, self.view_height, self.view_width)
@@ -347,58 +379,89 @@ class PicroCrafterEngine:
             + x[:, None, None]
         ).expand_as(n)
         v = self.worlds.new_full(
-            (self.worlds.size(0), self.view_height, self.view_width), self.token2id["#"]
+            (self.worlds.size(0), self.view_height + 1, self.view_width),
+            self.tile2id["#"],
         )
 
-        v[a[:, 0]] = self.worlds[n, i, j]
+        v[a[:, 0], : self.view_height] = self.worlds[n, i, j]
+
+        v[:, self.view_height] = self.tile2id["-"]
+        v[:, self.view_height, 0] = self.tile2id["0"] + (
+            self.life_level_in_100th // 100
+        ).clamp(min=0, max=self.life_level_max)
+        v[:, self.view_height, 1] = torch.tensor(
+            [
+                self.accessible_object_to_inventory[o.item()]
+                for o in self.accessible_object
+            ],
+            device=v.device,
+        )
 
         return v
 
+    def seq2tilepic(self, t, width):
+        def tile(n):
+            n = n.item()
+            if n in self.id2tile:
+                return self.id2tile[n]
+            else:
+                return "?"
+
+        if t.dim() == 2:
+            return [self.seq2tilepic(r, width) for r in t]
+
+        t = t.reshape(-1, width)
+
+        t = ["".join([tile(n) for n in r]) for r in t]
+
+        return t
+
     def print_worlds(
         self, src=None, comments=[], width=None, printer=print, ansi_term=False
     ):
         if src is None:
-            src = self.worlds
+            src = list(self.worlds)
 
-        if width is None:
-            width = src.size(2)
+        height = max([x.size(0) if torch.is_tensor(x) else 1 for x in src])
 
-        def token(n):
+        def tile(n):
             n = n.item()
-            if n in self.id2token:
-                return self.id2token[n]
+            if n in self.id2tile:
+                return self.id2tile[n]
             else:
                 return "?"
 
-        for k in range(src.size(1)):
-            s = ["".join([token(n) for n in m[k]]) for m in src]
-            s = [r + " " * (width - len(r)) for r in s]
-            if ansi_term:
-
-                def colorize(x):
-                    for u, c in [("#", 40), ("$", 31), ("@", 32)] + [
-                        (x, 36) for x in "aAbBcC"
-                    ]:
-                        x = x.replace(u, f"\u001b[{c}m{u}\u001b[0m")
-                    return x
+        for k in range(height):
 
-                s = [colorize(x) for x in s]
-            printer(" | ".join(s))
+            def f(x):
+                if torch.is_tensor(x):
+                    if x.dim() == 0:
+                        x = str(x.item())
+                        return " " * len(x) if k < height - 1 else x
+                    else:
+                        s = "".join([tile(n) for n in x[k]])
+                        if ansi_term:
+                            for u, c in [("#", 40), ("$", 31), ("@", 32)] + [
+                                (x, 36) for x in "aAbBcC"
+                            ]:
+                                s = s.replace(u, f"\u001b[{c}m{u}\u001b[0m")
+                        return s
+                else:
+                    return " " * len(x) if k < height - 1 else x
 
-        s = [c + " " * (width - len(c)) for c in comments]
-        printer(" | ".join(s))
+            printer("|".join([f(x) for x in src]))
 
 
 ######################################################################
 
 if __name__ == "__main__":
-    import os, time
+    import os, time, sys
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+    # nb_agents, nb_iter, display = 10000, 100, False
     # ansi_term = False
-    # nb_agents, nb_iter, display = 1000, 1000, False
-    nb_agents, nb_iter, display = 3, 10000, True
+    nb_agents, nb_iter, display = 4, 10000, True
     ansi_term = True
 
     start_time = time.perf_counter()
@@ -421,35 +484,48 @@ if __name__ == "__main__":
 
     start_time = time.perf_counter()
 
+    if ansi_term:
+        coloring = add_ansi_coloring
+    else:
+        coloring = lambda x: x
+
     stop = 0
     for k in range(nb_iter):
+        if display:
+            if ansi_term:
+                to_print = "\u001bc"
+                # print("\u001b[2J")
+            else:
+                to_print = ""
+                os.system("clear")
+
+            l = engine.seq2tilepic(engine.worlds.flatten(1), width=engine.world_width)
+
+            to_print += coloring(fusion_multi_lines(l)) + "\n\n"
+
+        views = engine.views()
         action = torch.randint(engine.nb_actions(), (nb_agents,), device=device)
-        rewards, inventories, life_levels = engine.step(
-            torch.randint(engine.nb_actions(), (nb_agents,), device=device)
-        )
+
+        rewards, inventories, life_levels = engine.step(action)
 
         if display:
-            os.system("clear")
-            engine.print_worlds(
-                ansi_term=ansi_term,
-            )
-            print()
-            engine.print_worlds(
-                src=engine.views(),
-                comments=[
-                    f"L{p}I{engine.id2token[s.item()]}R{r}"
-                    for p, s, r in zip(life_levels, inventories, rewards)
-                ],
-                width=engine.world_width,
-                ansi_term=ansi_term,
+            l = engine.seq2tilepic(views.flatten(1), engine.view_width)
+            l = [
+                v + [f"{engine.action2str(a.item())}/{r: 3d}"]
+                for (v, a, r) in zip(l, action, rewards)
+            ]
+
+            to_print += (
+                coloring(fusion_multi_lines(l, width_min=engine.world_width)) + "\n"
             )
+
+            print(to_print)
+            sys.stdout.flush()
             time.sleep(0.25)
 
         if (life_levels > 0).long().sum() == 0:
             stop += 1
-            if stop == 2:
+            if stop == 10:
                 break
 
-    print(
-        f"timing {(nb_agents*nb_iter)/(time.perf_counter() - start_time)} iteration per s"
-    )
+    print(f"timing {(nb_agents*k)/(time.perf_counter() - start_time)} iteration per s")