Update.
[picoclvr.git] / greed.py
index 636c13b..1025d7c 100755 (executable)
--- a/greed.py
+++ b/greed.py
@@ -11,6 +11,11 @@ from torch.nn import functional as F
 
 ######################################################################
 
+REWARD_PLUS = 1
+REWARD_NONE = 0
+REWARD_MINUS = -1
+REWARD_UNKNOWN = 2
+
 
 class GreedWorld:
     def __init__(self, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
@@ -36,11 +41,11 @@ class GreedWorld:
         )
 
         self.state_len = self.height * self.width
-        self.index_states = 0
-        self.index_reward = self.state_len
-        self.index_lookahead_reward = self.state_len + 1
+        self.index_lookahead_reward = 0
+        self.index_states = 1
+        self.index_reward = self.state_len + 1
         self.index_action = self.state_len + 2
-        self.it_len = self.state_len + 3  # lookahead_reward / state / action / reward
+        self.it_len = self.state_len + 3  # lookahead_reward / state / reward / action
 
     def state2code(self, r):
         return r + self.first_states_code
@@ -179,9 +184,9 @@ class GreedWorld:
 
         return torch.cat(
             [
+                self.lookahead_reward2code(s[:, :, None]),
                 self.state2code(states.flatten(2)),
                 self.reward2code(rewards[:, :, None]),
-                self.lookahead_reward2code(s[:, :, None]),
                 self.action2code(actions[:, :, None]),
             ],
             dim=2,