From 9141338f022ff991ac91e448eda0fd1cb401fd84 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 24 Mar 2024 16:13:37 +0100 Subject: [PATCH] Update. --- tasks.py | 54 +++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) diff --git a/tasks.py b/tasks.py index 2f3db6a..02c44bb 100755 --- a/tasks.py +++ b/tasks.py @@ -1909,6 +1909,54 @@ class Escape(Task): def vocabulary_size(self): return self.nb_codes + def thinking_autoregression( + self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000 + ): + result = self.test_input[:100].clone() + t = torch.arange(result.size(1), device=result.device) + itl = self.height * self.width + 3 + + def ar(): + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + + for u in range(itl, result.size(1) - itl + 1, itl): + print(f"{itl=} {u=} {result.size(1)=}") + result[:, u - 1] = (-1) + 1 + escape.first_lookahead_rewards_code + ar_mask = (t >= u).long() * (t < u + self.height * self.width).long() + ar_mask = ar_mask[None, :] + ar_mask = ar_mask.expand_as(result) + result *= 1 - ar_mask + ar() + result[:, u - 1] = (1) + 1 + escape.first_lookahead_rewards_code + ar_mask = (t >= self.height * self.width).long() * ( + t < self.height * self.width + 2 + ).long() + ar_mask = ar_mask[None, :] + ar_mask = ar_mask.expand_as(result) + result *= 1 - ar_mask + ar() + + # Saving the generated sequences + + s, a, r, lr = escape.seq2episodes( + result, self.height, self.width, lookahead=True + ) + str = escape.episodes2str( + s, a, r, lookahead_rewards=lr, unicode=True, ansi_colors=True + ) + + filename = os.path.join(result_dir, f"test_thinking_seq_{n_epoch:04d}.txt") + with open(filename, "w") as f: + f.write(str) + logger(f"wrote {filename}") + def produce_results( self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000 ): @@ -1932,7 +1980,7 @@ class Escape(Task): ar_mask = ( torch.arange(result.size(1), device=result.device) - > self.height * self.width + 2 + >= self.height * self.width + 3 ).long()[None, :] ar_mask = ar_mask.expand_as(result) result *= 1 - ar_mask # paraaaaanoiaaaaaaa @@ -1960,5 +2008,9 @@ class Escape(Task): f.write(str) logger(f"wrote {filename}") + self.thinking_autoregression( + n_epoch, model, result_dir, logger, deterministic_synthesis, nmax + ) + ###################################################################### -- 2.20.1