From 291c38d093894d46fba6eb45f82e5b65a2a1cb8b Mon Sep 17 00:00:00 2001
From: François Fleuret
Date: Sun, 23 Jul 2023 11:53:52 +0200
Subject: [PATCH] Update.

---
 graph.py | 2 +-
 tasks.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/graph.py b/graph.py
index bd80187..2c7caf8 100755
--- a/graph.py
+++ b/graph.py
@@ -59,7 +59,7 @@ def save_attention_image(
     ctx.set_line_width(0.25)
 
     for d in range(len(attention_matrices)):
-        at = attention_matrices[d]
+        at = attention_matrices[d].to("cpu")
         ni = torch.arange(at.size(0))[:, None].expand_as(at)
         nj = torch.arange(at.size(1))[None, :].expand_as(at)
         at = at.flatten()
diff --git a/tasks.py b/tasks.py
index 234e780..42d9126 100755
--- a/tasks.py
+++ b/tasks.py
@@ -1285,7 +1285,7 @@ class RPL(Task):
         if save_attention_image is not None:
             input = self.test_input[:1].clone()
             last = (input != self.t_nul).max(0).values.nonzero().max() + 3
-            input = input[:, :last]
+            input = input[:, :last].to(self.device)
 
         with torch.autograd.no_grad():
             t = model.training
-- 
2.20.1
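
Both hunks fix device placement: graph.py iterates over attention matrix entries on the host when drawing with Cairo, so tensors produced by a CUDA model must be moved to the CPU first, while tasks.py builds its attention-probe input from CPU tensors, which must be moved to the model's device before the forward pass. A minimal sketch of the two patterns follows; the surrounding model and drawing code are illustrative stand-ins, not the actual code in graph.py / tasks.py:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# graph.py pattern: the attention matrix may live on the GPU, but the
# plotting loop reads it entry by entry on the host, so move it to the
# CPU once up front.
attention = torch.rand(4, 4, device=device)  # stand-in for a model's attention map
at = attention.to("cpu")
ni = torch.arange(at.size(0))[:, None].expand_as(at)
nj = torch.arange(at.size(1))[None, :].expand_as(at)
for i, j, a in zip(ni.flatten(), nj.flatten(), at.flatten()):
    pass  # each a.item() is now a cheap CPU read rather than a per-element GPU sync

# tasks.py pattern: a probe input assembled from CPU tensors must follow
# the model's device, or PyTorch raises "Expected all tensors to be on
# the same device" during the forward pass.
input = torch.zeros(1, 8, dtype=torch.long)  # stand-in for the sliced test input
input = input.to(device)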