From: François Fleuret Date: Sun, 24 Mar 2024 10:36:44 +0000 (+0100) Subject: Update. X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=picoclvr.git;a=commitdiff_plain;h=9bf6dde1249fde5ba0ca2688599d8dd324d8c503 Update. --- diff --git a/escape.py b/escape.py index 93e3052..fc4fbbc 100755 --- a/escape.py +++ b/escape.py @@ -111,13 +111,12 @@ def episodes2seq(states, actions, rewards, lookahead_delta=None): if lookahead_delta is not None: r = rewards - print(f"{r.size()=} {lookahead_delta=}") u = F.pad(r, (0, lookahead_delta - 1)).as_strided( (r.size(0), r.size(1), lookahead_delta), (r.size(1) + lookahead_delta - 1, 1, 1), ) - a = u.min(dim=-1).values - b = u.max(dim=-1).values + a = u[:, :, 1:].min(dim=-1).values + b = u[:, :, 1:].max(dim=-1).values s = (a < 0).long() * a + (a >= 0).long() * b lookahead_rewards = (1 + s[:, :, None]) + first_lookahead_rewards_code