X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=ddpol.py;h=645f47cfee373e9cb51d12cc5637021469c0d549;hb=47525ec795faca1ab72aee13956a553d070c81b7;hp=9d14d2ad7d5ad75b979e16477e5047f9e690276e;hpb=fbc0246e209dd03b9c9865904039b565d7fa9f96;p=pytorch.git diff --git a/ddpol.py b/ddpol.py index 9d14d2a..645f47c 100755 --- a/ddpol.py +++ b/ddpol.py @@ -50,15 +50,18 @@ def fit_alpha(x, y, D, a = 0, b = 1, rho = 1e-12): r = q.view(-1, 1) beta = x.new_zeros(D + 1, D + 1) beta[2:, 2:] = (q-1) * q * (r-1) * r * (b**(q+r-3) - a**(q+r-3))/(q+r-3) - l, U = beta.eig(eigenvectors = True) - Q = U @ torch.diag(l[:, 0].clamp(min = 0) ** 0.5) # clamp deals with ~0 negative values + W = torch.linalg.eig(beta) + l, U = W.eigenvalues.real, W.eigenvectors.real + Q = U @ torch.diag(l.clamp(min = 0) ** 0.5) # clamp deals with ~0 negative values B = torch.cat((B, y.new_zeros(Q.size(0))), 0) M = torch.cat((M, math.sqrt(rho) * Q.t()), 0) - return torch.lstsq(B, M).solution[:D+1, 0] + return torch.linalg.lstsq(M, B).solution[:D+1] ###################################################################### +# The "ground truth" + def phi(x): return torch.abs(torch.abs(x - 0.4) - 0.2) + x/2 - 0.1 @@ -97,7 +100,7 @@ ax.set_ylabel('MSE', labelpad = 10) ax.axvline(x = args.nb_train_samples - 1, color = 'gray', linewidth = 0.5, linestyle = '--') -ax.text(args.nb_train_samples - 1.2, 1e-4, 'Nb. params = nb. samples', +ax.text(args.nb_train_samples - 1.2, 1e-4, 'nb. params = nb. samples', fontsize = 10, color = 'gray', rotation = 90, rotation_mode='anchor') @@ -111,6 +114,33 @@ fig.savefig('dd-mse.pdf', bbox_inches='tight') plt.close(fig) +###################################################################### +# Plot multiple MSE vs. degree curves + +fig = plt.figure() + +ax = fig.add_subplot(1, 1, 1) +ax.set_yscale('log') +ax.set_ylim(1e-5, 1) +ax.set_xlabel('Polynomial degree', labelpad = 10) +ax.set_ylabel('MSE', labelpad = 10) + +nb_train_samples_min = args.nb_train_samples - 4 +nb_train_samples_max = args.nb_train_samples + +for nb_train_samples in range(nb_train_samples_min, nb_train_samples_max + 1, 2): + mse_train, mse_test = compute_mse(nb_train_samples) + e = float(nb_train_samples - nb_train_samples_min) / float(nb_train_samples_max - nb_train_samples_min) + e = 0.15 + 0.7 * e + ax.plot(torch.arange(args.D_max + 1), mse_train, color = (e, e, 1.0), label = f'Train N={nb_train_samples}') + ax.plot(torch.arange(args.D_max + 1), mse_test, color = (1.0, e, e), label = f'Test N={nb_train_samples}') + +ax.legend(frameon = False) + +fig.savefig('dd-multi-mse.pdf', bbox_inches='tight') + +plt.close(fig) + ###################################################################### # Plot some examples of train / test