From 19ec7f3e4030ddece2647983dcf1bed5eb0d9544 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 25 Mar 2024 07:22:23 +0100 Subject: [PATCH] Update. --- escape.py | 36 ++++++++++++++++++------------------ tasks.py | 10 ++++++++-- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/escape.py b/escape.py index 43843f0..f51863b 100755 --- a/escape.py +++ b/escape.py @@ -94,7 +94,7 @@ def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3): ) hit = (hit > 0).long() - assert hit.min() == 0 and hit.max() <= 1 + # assert hit.min() == 0 and hit.max() <= 1 rewards[:, t + 1] = -hit + (1 - hit) * agent[:, t + 1, -1, -1] @@ -133,27 +133,27 @@ def episodes2seq(states, actions, rewards, lookahead_delta=None): r = rewards[:, :, None] rewards = (r + 1) + first_rewards_code - assert ( - states.min() >= first_state_code - and states.max() < first_state_code + nb_state_codes - ) - assert ( - actions.min() >= first_actions_code - and actions.max() < first_actions_code + nb_actions_codes - ) - assert ( - rewards.min() >= first_rewards_code - and rewards.max() < first_rewards_code + nb_rewards_codes - ) + # assert ( + # states.min() >= first_state_code + # and states.max() < first_state_code + nb_state_codes + # ) + # assert ( + # actions.min() >= first_actions_code + # and actions.max() < first_actions_code + nb_actions_codes + # ) + # assert ( + # rewards.min() >= first_rewards_code + # and rewards.max() < first_rewards_code + nb_rewards_codes + # ) if lookahead_delta is None: return torch.cat([states, actions, rewards], dim=2).flatten(1) else: - assert ( - lookahead_rewards.min() >= first_lookahead_rewards_code - and lookahead_rewards.max() - < first_lookahead_rewards_code + nb_lookahead_rewards_codes - ) + # assert ( + # lookahead_rewards.min() >= first_lookahead_rewards_code + # and lookahead_rewards.max() + # < first_lookahead_rewards_code + nb_lookahead_rewards_codes + # ) return torch.cat([states, actions, rewards, lookahead_rewards], dim=2).flatten( 1 ) diff --git a/tasks.py b/tasks.py index fddcaff..5153836 100755 --- a/tasks.py +++ b/tasks.py @@ -1938,8 +1938,13 @@ class Escape(Task): range(it_len, result.size(1) - it_len + 1, it_len), desc="thinking" ): # Put the lookahead reward to either 0 or -1 for the - # current iteration, sample the next state - s = -(torch.rand(result.size(0), device=result.device) < 0.2).long() + # current iteration, with a proba that depends with the + # sequence index, so that we have diverse examples, sample + # the next state + s = -( + torch.rand(result.size(0), device=result.device) + <= torch.linspace(0, 1, result.size(0), device=result.device) + ).long() result[:, u - 1] = s + 1 + escape.first_lookahead_rewards_code ar_mask = (t >= u).long() * (t < u + state_len).long() ar(result, ar_mask) @@ -1956,6 +1961,7 @@ class Escape(Task): # Extract the rewards r = result[:, range(v + state_len + 1 + it_len, u + it_len - 1, it_len)] r = r - escape.first_rewards_code - 1 + r = r.clamp(min=-1, max=1) # the reward is predicted hence can be weird a = r.min(dim=1).values b = r.max(dim=1).values s = (a < 0).long() * a + (a >= 0).long() * b -- 2.20.1