Update.
[pytorch.git] / picocrafter.py
index 33a00c1..31ba1e4 100755 (executable)
 # 5pt.
 #
 # The agent can carry "keys" ("a", "b", "c") that open "vaults" ("A",
-# "B", "C"). They keys can only be used in sequence: initially the
-# agent can move only to free spaces, or to the "a", in which case it
-# now carries it, and can move to free spaces or the "A". When it
-# moves to the "A", it gets a reward and loses the "a", but can now
-# move to the "b", etc. Rewards are 1 for "A" and "B" and 10 for "C".
+# "B", "C"). The keys and vault can only be used in sequence:
+# initially the agent can move only to free spaces, or to the "a", in
+# which case the key is removed from the environment and the agent now
+# carries it, and can move to free spaces or the "A". When it moves to
+# the "A", it gets a reward, loses the "a", the "A" is removed from
+# the environment, but can now move to the "b", etc. Rewards are 1 for
+# "A" and "B" and 10 for "C".
 
 ######################################################################
 
@@ -90,6 +92,7 @@ class PicroCrafterEngine:
                     ("b", "B"),
                     ("B", "c"),
                     ("c", "C"),
+                    ("C", " "),
                 ]
             ]
         )
@@ -245,6 +248,11 @@ class PicroCrafterEngine:
             ]
         )
 
+        self.life_level_in_100th = (
+            self.life_level_in_100th
+            * (self.accessible_object != self.token2id[" "]).long()
+        )
+
         reward[torch.logical_not(alive_before)] = 0
         return reward, inventory, self.life_level_in_100th // 100
 
@@ -387,7 +395,7 @@ if __name__ == "__main__":
     ansi_term = False
     # nb_agents, nb_iter, display = 1000, 100, False
     nb_agents, nb_iter, display = 3, 10000, True
-    ansi_term = True
+    ansi_term = True
 
     start_time = time.perf_counter()
     engine = PicroCrafterEngine(
@@ -427,7 +435,7 @@ if __name__ == "__main__":
                 width=engine.world_width,
                 ansi_term=ansi_term,
             )
-            time.sleep(0.5)
+            time.sleep(0.25)
 
         if (life_levels > 0).long().sum() == 0:
             break