Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 4 Nov 2023 11:12:30 +0000 (12:12 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 4 Nov 2023 11:12:30 +0000 (12:12 +0100)
picocrafter.py

index ef861bc..e303554 100755 (executable)
@@ -22,8 +22,9 @@
 # it can initialize ~20k environments per second and run ~40k
 # iterations.
 #
-# The agent "@" moves in a maze-like grid with random walls "#". There
-# are five actions: move NESW or do not move.
+# The environment is a rectangular area with walls "#" dispatched
+# randomly. The agent "@" can perform five actions: move NESW or do
+# not move.
 #
 # There are monsters "$" moving randomly. The agent gets hit by every
 # monster present in one of the 4 direct neighborhoods at the end of
@@ -40,8 +41,8 @@
 # which case the key is removed from the environment and the agent now
 # carries it, and can move to free spaces or the "A". When it moves to
 # the "A", it gets a reward, loses the "a", the "A" is removed from
-# the environment, but can now move to the "b", etc. Rewards are 1 for
-# "A" and "B" and 10 for "C".
+# the environment, but the agent can now move to the "b", etc. Rewards
+# are 1 for "A" and "B" and 10 for "C".
 
 ######################################################################
 
@@ -52,22 +53,36 @@ from torch.nn.functional import conv2d
 ######################################################################
 
 
-def add_ansi_coloring(s):
+def to_ansi(s):
     if type(s) is list:
-        return [add_ansi_coloring(x) for x in s]
+        return [to_ansi(x) for x in s]
 
-    for u, c in [("#", 40), ("$", 31), ("@", 32)] + [(x, 36) for x in "aAbBcC"]:
+    for u, c in [("$", 31), ("@", 32)] + [(x, 36) for x in "aAbBcC"]:
         s = s.replace(u, f"\u001b[{c}m{u}\u001b[0m")
 
     return s
 
 
+def to_unicode(s):
+    if type(s) is list:
+        return [to_unicode(x) for x in s]
+
+    for u, c in [("#", "█"), ("+", "░"), ("|", "│")]:
+        s = s.replace(u, c)
+
+    return s
+
+
 def fusion_multi_lines(l, width_min=0):
     l = [x if type(x) is list else [str(x)] for x in l]
 
+    def center(r, w):
+        k = w - len(r)
+        return " " * (k // 2) + r + " " * (k - k // 2)
+
     def f(o, h):
         w = max(width_min, max([len(r) for r in o]))
-        return [" " * w] * (h - len(o)) + [r + " " * (w - len(r)) for r in o]
+        return [" " * w] * (h - len(o)) + [center(r, w) for r in o]
 
     h = max([len(x) for x in l])
     l = [f(o, h) for o in l]
@@ -416,41 +431,6 @@ class PicroCrafterEngine:
 
         return t
 
-    def print_worlds(
-        self, src=None, comments=[], width=None, printer=print, ansi_term=False
-    ):
-        if src is None:
-            src = list(self.worlds)
-
-        height = max([x.size(0) if torch.is_tensor(x) else 1 for x in src])
-
-        def tile(n):
-            n = n.item()
-            if n in self.id2tile:
-                return self.id2tile[n]
-            else:
-                return "?"
-
-        for k in range(height):
-
-            def f(x):
-                if torch.is_tensor(x):
-                    if x.dim() == 0:
-                        x = str(x.item())
-                        return " " * len(x) if k < height - 1 else x
-                    else:
-                        s = "".join([tile(n) for n in x[k]])
-                        if ansi_term:
-                            for u, c in [("#", 40), ("$", 31), ("@", 32)] + [
-                                (x, 36) for x in "aAbBcC"
-                            ]:
-                                s = s.replace(u, f"\u001b[{c}m{u}\u001b[0m")
-                        return s
-                else:
-                    return " " * len(x) if k < height - 1 else x
-
-            printer("|".join([f(x) for x in src]))
-
 
 ######################################################################
 
@@ -459,19 +439,23 @@ if __name__ == "__main__":
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    # nb_agents, nb_iter, display = 10000, 100, False
+    # char_conv = lambda x: x
+    char_conv = to_unicode
+
+    # nb_agents, nb_iter, display = 1000, 1000, False
     # ansi_term = False
+
     nb_agents, nb_iter, display = 4, 10000, True
     ansi_term = True
 
+    if ansi_term:
+        char_conv = lambda x: to_ansi(to_unicode(x))
+
     start_time = time.perf_counter()
     engine = PicroCrafterEngine(
         world_height=27,
         world_width=27,
         nb_walls=35,
-        # world_height=15,
-        # world_width=15,
-        # nb_walls=0,
         view_height=9,
         view_width=9,
         margin=4,
@@ -484,11 +468,6 @@ if __name__ == "__main__":
 
     start_time = time.perf_counter()
 
-    if ansi_term:
-        coloring = add_ansi_coloring
-    else:
-        coloring = lambda x: x
-
     stop = 0
     for k in range(nb_iter):
         if display:
@@ -501,7 +480,7 @@ if __name__ == "__main__":
 
             l = engine.seq2tilepic(engine.worlds.flatten(1), width=engine.world_width)
 
-            to_print += coloring(fusion_multi_lines(l)) + "\n\n"
+            to_print += char_conv(fusion_multi_lines(l)) + "\n\n"
 
         views = engine.views()
         action = torch.randint(engine.nb_actions(), (nb_agents,), device=device)
@@ -516,7 +495,7 @@ if __name__ == "__main__":
             ]
 
             to_print += (
-                coloring(fusion_multi_lines(l, width_min=engine.world_width)) + "\n"
+                char_conv(fusion_multi_lines(l, width_min=engine.world_width)) + "\n"
             )
 
             print(to_print)