Update.
[picoclvr.git] / graph.py
index bd80187..2c7caf8 100755 (executable)
--- a/graph.py
+++ b/graph.py
@@ -59,7 +59,7 @@ def save_attention_image(
 
     ctx.set_line_width(0.25)
     for d in range(len(attention_matrices)):
-        at = attention_matrices[d]
+        at = attention_matrices[d].to("cpu")
         ni = torch.arange(at.size(0))[:, None].expand_as(at)
         nj = torch.arange(at.size(1))[None, :].expand_as(at)
         at = at.flatten()