Minor syntactic change.
[pytorch.git] / ddpol.py
index 9a1bbc9..6812fdf 100755 (executable)
--- a/ddpol.py
+++ b/ddpol.py
@@ -34,7 +34,7 @@ def fit_alpha(x, y, D, a = 0, b = 1, rho = 1e-12):
         B = torch.cat((B, y.new_zeros(Q.size(0))), 0)
         M = torch.cat((M, math.sqrt(rho) * Q.t()), 0)
 
-    return torch.lstsq(B, M).solution.view(-1)[:D+1]
+    return torch.lstsq(B, M).solution[:D+1, 0]
 
 ######################################################################