Update.
[pytorch.git] / ddpol.py
index 6ace38d..645f47c 100755 (executable)
--- a/ddpol.py
+++ b/ddpol.py
@@ -50,15 +50,18 @@ def fit_alpha(x, y, D, a = 0, b = 1, rho = 1e-12):
         r = q.view(-1,  1)
         beta = x.new_zeros(D + 1, D + 1)
         beta[2:, 2:] = (q-1) * q * (r-1) * r * (b**(q+r-3) - a**(q+r-3))/(q+r-3)
-        l, U = beta.eig(eigenvectors = True)
-        Q = U @ torch.diag(l[:, 0].clamp(min = 0) ** 0.5) # clamp deals with ~0 negative values
+        W = torch.linalg.eig(beta)
+        l, U = W.eigenvalues.real, W.eigenvectors.real
+        Q = U @ torch.diag(l.clamp(min = 0) ** 0.5) # clamp deals with ~0 negative values
         B = torch.cat((B, y.new_zeros(Q.size(0))), 0)
         M = torch.cat((M, math.sqrt(rho) * Q.t()), 0)
 
-    return torch.lstsq(B, M).solution[:D+1, 0]
+    return torch.linalg.lstsq(M, B).solution[:D+1]
 
 ######################################################################
 
+# The "ground truth"
+
 def phi(x):
     return torch.abs(torch.abs(x - 0.4) - 0.2) + x/2 - 0.1
 
@@ -97,7 +100,7 @@ ax.set_ylabel('MSE', labelpad = 10)
 ax.axvline(x = args.nb_train_samples - 1,
            color = 'gray', linewidth = 0.5, linestyle = '--')
 
-ax.text(args.nb_train_samples - 1.2, 1e-4, 'Nb. params = nb. samples',
+ax.text(args.nb_train_samples - 1.2, 1e-4, 'nb. params = nb. samples',
         fontsize = 10, color = 'gray',
         rotation = 90, rotation_mode='anchor')
 
@@ -132,6 +135,8 @@ for nb_train_samples in range(nb_train_samples_min, nb_train_samples_max + 1, 2)
     ax.plot(torch.arange(args.D_max + 1), mse_train, color = (e, e, 1.0), label = f'Train N={nb_train_samples}')
     ax.plot(torch.arange(args.D_max + 1), mse_test, color = (1.0, e, e), label = f'Test N={nb_train_samples}')
 
+ax.legend(frameon = False)
+
 fig.savefig('dd-multi-mse.pdf', bbox_inches='tight')
 
 plt.close(fig)