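Update: rename the `margin` parameter/attribute of PicroCrafterEngine to `world_margin`, and rename `seq2tilepic` to `seq2tiles` with `width` now optional (defaulting to `view_width`).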
author: François Fleuret <francois@fleuret.org>
Mon, 6 Nov 2023 07:06:33 +0000 (08:06 +0100)
committer: François Fleuret <francois@fleuret.org>
Mon, 6 Nov 2023 07:06:33 +0000 (08:06 +0100)
picocrafter.py

index e303554..36088ac 100755 (executable)
@@ -96,19 +96,19 @@ class PicroCrafterEngine:
         world_height=27,
         world_width=27,
         nb_walls=27,
-        margin=2,
+        world_margin=2,
         view_height=5,
         view_width=5,
         device=torch.device("cpu"),
     ):
-        assert (world_height - 2 * margin) % (view_height - 2 * margin) == 0
-        assert (world_width - 2 * margin) % (view_width - 2 * margin) == 0
+        assert (world_height - 2 * world_margin) % (view_height - 2 * world_margin) == 0
+        assert (world_width - 2 * world_margin) % (view_width - 2 * world_margin) == 0
 
         self.device = device
 
         self.world_height = world_height
         self.world_width = world_width
-        self.margin = margin
+        self.world_margin = world_margin
         self.view_height = view_height
         self.view_width = view_width
         self.nb_walls = nb_walls
@@ -168,7 +168,11 @@ class PicroCrafterEngine:
 
     def reset(self, nb_agents):
         self.worlds = self.create_worlds(
-            nb_agents, self.world_height, self.world_width, self.nb_walls, self.margin
+            nb_agents,
+            self.world_height,
+            self.world_width,
+            self.nb_walls,
+            self.world_margin,
         ).to(self.device)
         self.life_level_in_100th = torch.full(
             (nb_agents,), self.life_level_max * 100 + 99, device=self.device
@@ -209,9 +213,11 @@ class PicroCrafterEngine:
 
         return m
 
-    def create_worlds(self, nb, height, width, nb_walls, margin=2):
-        margin -= 1  # The maze adds a wall all around
-        m = self.create_mazes(nb, height - 2 * margin, width - 2 * margin, nb_walls)
+    def create_worlds(self, nb, height, width, nb_walls, world_margin=2):
+        world_margin -= 1  # The maze adds a wall all around
+        m = self.create_mazes(
+            nb, height - 2 * world_margin, width - 2 * world_margin, nb_walls
+        )
         q = m.flatten(1)
         z = "@aAbBcC$$$$$"  # What to add to the maze
         u = torch.rand(q.size(), device=q.device) * (1 - q)
@@ -222,12 +228,12 @@ class PicroCrafterEngine:
             torch.arange(q.size(0), device=q.device)[:, None].expand_as(r), r
         ] = torch.tensor([self.tile2id[c] for c in z], device=q.device)[None, :]
 
-        if margin > 0:
+        if world_margin > 0:
             r = m.new_full(
-                (m.size(0), m.size(1) + margin * 2, m.size(2) + margin * 2),
+                (m.size(0), m.size(1) + world_margin * 2, m.size(2) + world_margin * 2),
                 self.tile2id["+"],
             )
-            r[:, margin:-margin, margin:-margin] = m
+            r[:, world_margin:-world_margin, world_margin:-world_margin] = m
             m = r
         return m
 
@@ -378,12 +384,12 @@ class PicroCrafterEngine:
 
     def views(self):
         i_height, i_width = (
-            self.view_height - 2 * self.margin,
-            self.view_width - 2 * self.margin,
+            self.view_height - 2 * self.world_margin,
+            self.view_width - 2 * self.world_margin,
         )
         a = (self.worlds == self.tile2id["@"]).nonzero()
-        y = i_height * ((a[:, 1] - self.margin) // i_height)
-        x = i_width * ((a[:, 2] - self.margin) // i_width)
+        y = i_height * ((a[:, 1] - self.world_margin) // i_height)
+        x = i_width * ((a[:, 2] - self.world_margin) // i_width)
         n = a[:, 0][:, None, None].expand(-1, self.view_height, self.view_width)
         i = (
             torch.arange(self.view_height, device=a.device)[None, :, None]
@@ -414,7 +420,7 @@ class PicroCrafterEngine:
 
         return v
 
-    def seq2tilepic(self, t, width):
+    def seq2tiles(self, t, width=None):
         def tile(n):
             n = n.item()
             if n in self.id2tile:
@@ -423,7 +429,10 @@ class PicroCrafterEngine:
                 return "?"
 
         if t.dim() == 2:
-            return [self.seq2tilepic(r, width) for r in t]
+            return [self.seq2tiles(r, width) for r in t]
+
+        if width is None:
+            width = self.view_width
 
         t = t.reshape(-1, width)
 
@@ -458,7 +467,7 @@ if __name__ == "__main__":
         nb_walls=35,
         view_height=9,
         view_width=9,
-        margin=4,
+        world_margin=4,
         device=device,
     )
 
@@ -478,7 +487,7 @@ if __name__ == "__main__":
                 to_print = ""
                 os.system("clear")
 
-            l = engine.seq2tilepic(engine.worlds.flatten(1), width=engine.world_width)
+            l = engine.seq2tiles(engine.worlds.flatten(1), width=engine.world_width)
 
             to_print += char_conv(fusion_multi_lines(l)) + "\n\n"
 
@@ -488,7 +497,7 @@ if __name__ == "__main__":
         rewards, inventories, life_levels = engine.step(action)
 
         if display:
-            l = engine.seq2tilepic(views.flatten(1), engine.view_width)
+            l = engine.seq2tiles(views.flatten(1))
             l = [
                 v + [f"{engine.action2str(a.item())}/{r: 3d}"]
                 for (v, a, r) in zip(l, action, rewards)