Update: move attention matrices to the CPU before drawing, and send the truncated test input to the model's device.
author     François Fleuret <francois@fleuret.org>
           Sun, 23 Jul 2023 09:53:52 +0000 (11:53 +0200)
committer  François Fleuret <francois@fleuret.org>
           Sun, 23 Jul 2023 09:53:52 +0000 (11:53 +0200)
graph.py
tasks.py

index bd80187..2c7caf8 100755
--- a/graph.py
+++ b/graph.py
@@ -59,7 +59,7 @@ def save_attention_image(
 
     ctx.set_line_width(0.25)
     for d in range(len(attention_matrices)):
-        at = attention_matrices[d]
+        at = attention_matrices[d].to("cpu")
         ni = torch.arange(at.size(0))[:, None].expand_as(at)
         nj = torch.arange(at.size(1))[None, :].expand_as(at)
         at = at.flatten()
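
The graph.py change keeps the drawing loop on a single device: the index grids built with torch.arange live on the CPU, so an attention matrix left on a CUDA device would mix devices in the masking and indexing that follow. A minimal sketch of the pattern; the helper name and threshold are hypothetical, not from the repository.

    import torch

    def cells_above_threshold(at, threshold=0.1):
        # Hypothetical helper mirroring the inner loop of save_attention_image.
        at = at.to("cpu")  # align with the CPU index grids below
        ni = torch.arange(at.size(0))[:, None].expand_as(at)
        nj = torch.arange(at.size(1))[None, :].expand_as(at)
        at, ni, nj = at.flatten(), ni.flatten(), nj.flatten()
        keep = at > threshold  # every operand is now on the CPU
        return ni[keep], nj[keep], at[keep]
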
index 234e780..42d9126 100755
--- a/tasks.py
+++ b/tasks.py
@@ -1285,7 +1285,7 @@ class RPL(Task):
         if save_attention_image is not None:
             input = self.test_input[:1].clone()
             last = (input != self.t_nul).max(0).values.nonzero().max() + 3
-            input = input[:, :last]
+            input = input[:, :last].to(self.device)
 
             with torch.autograd.no_grad():
                 t = model.training
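
The tasks.py change is the complementary move: the truncated slice of self.test_input is sent to self.device so the subsequent forward pass runs on the same device as the model. A minimal standalone sketch, with a hypothetical helper standing in for the RPL method:

    import torch

    def truncated_test_batch(test_input, t_nul, device):
        # Hypothetical helper: keep one sequence and cut it just past the
        # last non-null token, with the same margin of 3 as in the RPL task.
        input = test_input[:1].clone()
        last = (input != t_nul).max(0).values.nonzero().max() + 3
        return input[:, :last].to(device)  # match the model's device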