X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=picocrafter.py;h=31ba1e4039e79f192aea1e9b544560bd516fb551;hb=7cf92d14892ccce7c5a1eaa38c0d6b8fff03e751;hp=33a00c126e5309924933719fcf0a32bdcdffaf2a;hpb=a2ccdd2f5e9fb3e7ed52492729b880f815ddfbcb;p=pytorch.git diff --git a/picocrafter.py b/picocrafter.py index 33a00c1..31ba1e4 100755 --- a/picocrafter.py +++ b/picocrafter.py @@ -35,11 +35,13 @@ # 5pt. # # The agent can carry "keys" ("a", "b", "c") that open "vaults" ("A", -# "B", "C"). They keys can only be used in sequence: initially the -# agent can move only to free spaces, or to the "a", in which case it -# now carries it, and can move to free spaces or the "A". When it -# moves to the "A", it gets a reward and loses the "a", but can now -# move to the "b", etc. Rewards are 1 for "A" and "B" and 10 for "C". +# "B", "C"). The keys and vault can only be used in sequence: +# initially the agent can move only to free spaces, or to the "a", in +# which case the key is removed from the environment and the agent now +# carries it, and can move to free spaces or the "A". When it moves to +# the "A", it gets a reward, loses the "a", the "A" is removed from +# the environment, but can now move to the "b", etc. Rewards are 1 for +# "A" and "B" and 10 for "C". ###################################################################### @@ -90,6 +92,7 @@ class PicroCrafterEngine: ("b", "B"), ("B", "c"), ("c", "C"), + ("C", " "), ] ] ) @@ -245,6 +248,11 @@ class PicroCrafterEngine: ] ) + self.life_level_in_100th = ( + self.life_level_in_100th + * (self.accessible_object != self.token2id[" "]).long() + ) + reward[torch.logical_not(alive_before)] = 0 return reward, inventory, self.life_level_in_100th // 100 @@ -387,7 +395,7 @@ if __name__ == "__main__": ansi_term = False # nb_agents, nb_iter, display = 1000, 100, False nb_agents, nb_iter, display = 3, 10000, True - ansi_term = True + # ansi_term = True start_time = time.perf_counter() engine = PicroCrafterEngine( @@ -427,7 +435,7 @@ if __name__ == "__main__": width=engine.world_width, ansi_term=ansi_term, ) - time.sleep(0.5) + time.sleep(0.25) if (life_levels > 0).long().sum() == 0: break