-grid = torch.linspace(-1.2,1.2,sg)
-grid = torch.cat((grid[:,None,None].expand(sg,sg,1),grid[None,:,None].expand(sg,sg,1)),-1).reshape(-1,2)
+grid = torch.linspace(-1.2, 1.2, sg)
+grid = torch.cat(
+ (grid[:, None, None].expand(sg, sg, 1), grid[None, :, None].expand(sg, sg, 1)), -1
+).reshape(-1, 2)