X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=escape.py;h=6f4af359730f792dcdb22b518c2c3a4b704434c8;hb=2be22c9825d8aebe8d184e9501355a31318abf2b;hp=f51863b0c42e9c4c0d38fa334296c6ebfca9ed04;hpb=621231cc5bb94f983c556a1b450b66067bec4165;p=picoclvr.git

diff --git a/escape.py b/escape.py
index f51863b..6f4af35 100755
--- a/escape.py
+++ b/escape.py
@@ -25,6 +25,33 @@ nb_codes = first_lookahead_rewards_code + nb_lookahead_rewards_codes
 ######################################################################
 
 
+def action2code(r):
+    return first_actions_code + r
+
+
+def code2action(r):
+    return r - first_actions_code
+
+
+def reward2code(r):
+    return first_rewards_code + r + 1
+
+
+def code2reward(r):
+    return r - first_rewards_code - 1
+
+
+def lookahead_reward2code(r):
+    return first_lookahead_rewards_code + r + 1
+
+
+def code2lookahead_reward(r):
+    return r - first_lookahead_rewards_code - 1
+
+
+######################################################################
+
+
 def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3):
     rnd = torch.rand(nb, height, width)
     rnd[:, 0, :] = 0
@@ -111,17 +138,6 @@ def episodes2seq(states, actions, rewards, lookahead_delta=None):
     actions = actions[:, :, None] + first_actions_code
 
     if lookahead_delta is not None:
-        # r = rewards
-        # u = F.pad(r, (0, lookahead_delta - 1)).as_strided(
-        # (r.size(0), r.size(1), lookahead_delta),
-        # (r.size(1) + lookahead_delta - 1, 1, 1),
-        # )
-        # a = u[:, :, 1:].min(dim=-1).values
-        # b = u[:, :, 1:].max(dim=-1).values
-        # s = (a < 0).long() * a + (a >= 0).long() * b
-        # lookahead_rewards = (1 + s[:, :, None]) + first_lookahead_rewards_code
-
-        # a[n,t]=min_s>t r[n,s]
         a = rewards.new_zeros(rewards.size())
         b = rewards.new_zeros(rewards.size())
         for t in range(a.size(1) - 1):