X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;fp=tasks.py;h=3ef64d7002fafb6e57d8e2a7933795341535bdae;hb=c3581ba868cd30cb45fbe2f97b80ddbc1fc26bbb;hp=324376df60319e9549ae431c5d43dd04f1a29ed9;hpb=232299b8af7e66a02e64bb2e47b525e2f50b099d;p=picoclvr.git diff --git a/tasks.py b/tasks.py index 324376d..3ef64d7 100755 --- a/tasks.py +++ b/tasks.py @@ -1944,7 +1944,7 @@ class Greed(Task): progress_bar_desc=None, ) warnings.warn("keeping thinking snapshots", RuntimeWarning) - snapshots.append(result[:10].detach().clone()) + snapshots.append(result[:100].detach().clone()) # Generate iteration after iteration @@ -1986,11 +1986,11 @@ class Greed(Task): # Set the lookahead_reward to UNKNOWN for the next iterations result[ :, u + self.world.index_lookahead_reward - ] = self.world.lookahead_reward2code(gree.REWARD_UNKNOWN) + ] = self.world.lookahead_reward2code(greed.REWARD_UNKNOWN) filename = os.path.join(result_dir, f"test_thinking_compute_{n_epoch:04d}.txt") with open(filename, "w") as f: - for n in range(10): + for n in range(snapshots[0].size(0)): for s in snapshots: lr, s, a, r = self.world.seq2episodes( s[n : n + 1],