X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=ddpol.py;h=6812fdff152e10290ae8608b4a0672a91b3eec37;hp=9a1bbc955df3611617194d4efce91f02eebc4d36;hb=546c9d6c72acd590018106a06afc5f1efc1be1a0;hpb=329837be2f41d7839046cc5ab0825b824825bf84 diff --git a/ddpol.py b/ddpol.py index 9a1bbc9..6812fdf 100755 --- a/ddpol.py +++ b/ddpol.py @@ -34,7 +34,7 @@ def fit_alpha(x, y, D, a = 0, b = 1, rho = 1e-12): B = torch.cat((B, y.new_zeros(Q.size(0))), 0) M = torch.cat((M, math.sqrt(rho) * Q.t()), 0) - return torch.lstsq(B, M).solution.view(-1)[:D+1] + return torch.lstsq(B, M).solution[:D+1, 0] ######################################################################