X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=42d912674641db0fec26a48ce5687b7040cba5cb;hb=291c38d093894d46fba6eb45f82e5b65a2a1cb8b;hp=234e78060a3274026f50ab400e434af8e3528350;hpb=40d0010b6b76304e340ae734cb9814e714b691cc;p=picoclvr.git diff --git a/tasks.py b/tasks.py index 234e780..42d9126 100755 --- a/tasks.py +++ b/tasks.py @@ -1285,7 +1285,7 @@ class RPL(Task): if save_attention_image is not None: input = self.test_input[:1].clone() last = (input != self.t_nul).max(0).values.nonzero().max() + 3 - input = input[:, :last] + input = input[:, :last].to(self.device) with torch.autograd.no_grad(): t = model.training