Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 22 Oct 2023 17:53:59 +0000 (19:53 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 22 Oct 2023 17:53:59 +0000 (19:53 +0200)
problems.py

index 4059856..b8fcdb3 100755 (executable)
@@ -313,8 +313,8 @@ class ProblemMixing(Problem):
         return y
 
     def start_error(self, x):
-        i = torch.arange(self.height).reshape(1,-1,1).expand_as(x)
-        j = torch.arange(self.width).reshape(1,1,-1).expand_as(x)
+        i = torch.arange(self.height, device=x.device).reshape(1,-1,1).expand_as(x)
+        j = torch.arange(self.width, device=x.device).reshape(1,1,-1).expand_as(x)
 
         ri = (x == self.height * self.width).long().sum(dim=-1).argmax(-1).view(-1,1,1)
         rj = (x == self.height * self.width).long().sum(dim=-2).argmax(-1).view(-1,1,1)