projects
/
picoclvr.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
41164ce
)
Update.
author
François Fleuret
<francois@fleuret.org>
Wed, 27 Mar 2024 20:35:32 +0000
(21:35 +0100)
committer
François Fleuret
<francois@fleuret.org>
Wed, 27 Mar 2024 20:35:32 +0000
(21:35 +0100)
greed.py
patch
|
blob
|
history
diff --git a/greed.py b/greed.py
index 47cfb40..636c13b 100755 (executable)
--- a/greed.py
+++ b/greed.py
@@ -172,7 +172,7 @@ class GreedWorld:
def episodes2seq(self, states, actions, rewards):
neg = rewards.new_zeros(rewards.size())
pos = rewards.new_zeros(rewards.size())
def episodes2seq(self, states, actions, rewards):
neg = rewards.new_zeros(rewards.size())
pos = rewards.new_zeros(rewards.size())
- for t in range(neg.size(1) - 1):
+ for t in range(neg.size(1)):
neg[:, t] = rewards[:, t:].min(dim=-1).values
pos[:, t] = rewards[:, t:].max(dim=-1).values
s = (neg < 0).long() * neg + (neg >= 0).long() * pos
neg[:, t] = rewards[:, t:].min(dim=-1).values
pos[:, t] = rewards[:, t:].max(dim=-1).values
s = (neg < 0).long() * neg + (neg >= 0).long() * pos
@@ -189,11 +189,15 @@ class GreedWorld:
def seq2episodes(self, seq):
seq = seq.reshape(seq.size(0), -1, self.height * self.width + 3)
def seq2episodes(self, seq):
seq = seq.reshape(seq.size(0), -1, self.height * self.width + 3)
- lookahead_rewards = self.code2lookahead_reward(seq[:, :, 0])
- states = self.code2state(seq[:, :, 1 : self.height * self.width + 1])
+ lookahead_rewards = self.code2lookahead_reward(
+ seq[:, :, self.index_lookahead_reward]
+ )
+ states = self.code2state(
+ seq[:, :, self.index_states : self.height * self.width + self.index_states]
+ )
states = states.reshape(states.size(0), states.size(1), self.height, self.width)
states = states.reshape(states.size(0), states.size(1), self.height, self.width)
- actions = self.code2action(seq[:, :, self.height * self.width + 1])
- rewards = self.code2reward(seq[:, :, self.height * self.width + 2])
+ actions = self.code2action(seq[:, :, self.index_action])
+ rewards = self.code2reward(seq[:, :, self.index_reward])
return lookahead_rewards, states, actions, rewards
def seq2str(self, seq):
return lookahead_rewards, states, actions, rewards
def seq2str(self, seq):