Update.
[picoclvr.git] / escape.py
index a3d8c85..7596bea 100755 (executable)
--- a/escape.py
+++ b/escape.py
@@ -14,7 +14,7 @@ from torch.nn import functional as F
 nb_states_codes = 5
 nb_actions_codes = 5
 nb_rewards_codes = 3
-nb_lookahead_rewards_codes = 3
+nb_lookahead_rewards_codes = 4  # stands for -1, 0, +1, and UNKNOWN
 
 first_states_code = 0
 first_actions_code = first_states_code + nb_states_codes
@@ -50,6 +50,7 @@ def code2reward(r):
 
 
 def lookahead_reward2code(r):
+    # -1, 0, +1 or 2 for UNKNOWN
     return r + 1 + first_lookahead_rewards_code
 
 
@@ -60,7 +61,7 @@ def code2lookahead_reward(r):
 ######################################################################
 
 
-def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3, nb_coins=3):
+def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
     rnd = torch.rand(nb, height, width)
     rnd[:, 0, :] = 0
     rnd[:, -1, :] = 0
@@ -195,7 +196,7 @@ def seq2str(seq):
             t >= first_lookahead_rewards_code
             and t < first_lookahead_rewards_code + nb_lookahead_rewards_codes
         ):
-            return "n.p"[t - first_lookahead_rewards_code]
+            return "n.pU"[t - first_lookahead_rewards_code]
         else:
             return "?"