Update.
[picoclvr.git] / graph.py
index 08f1170..2c7caf8 100755 (executable)
--- a/graph.py
+++ b/graph.py
@@ -20,8 +20,8 @@ def save_attention_image(
     attention_matrices,  # list of 2d tensors T1xT2, T2xT3, ..., Tk-1xTk
     # do not draw links with a lesser attention
     min_link_attention=0,
-    # draw only the strongest links necessary to reache
-    # min_total_attention
+    # draw only the strongest links necessary so that their summed
+    # attention is above min_total_attention
     min_total_attention=None,
     # draw only the top k links
     k_top=None,
@@ -59,7 +59,7 @@ def save_attention_image(
 
     ctx.set_line_width(0.25)
     for d in range(len(attention_matrices)):
-        at = attention_matrices[d]
+        at = attention_matrices[d].to("cpu")
         ni = torch.arange(at.size(0))[:, None].expand_as(at)
         nj = torch.arange(at.size(1))[None, :].expand_as(at)
         at = at.flatten()