Update.
[picoclvr.git] / tasks.py
index 234e780..42d9126 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -1285,7 +1285,7 @@ class RPL(Task):
         if save_attention_image is not None:
             input = self.test_input[:1].clone()
             last = (input != self.t_nul).max(0).values.nonzero().max() + 3
-            input = input[:, :last]
+            input = input[:, :last].to(self.device)
 
             with torch.autograd.no_grad():
                 t = model.training