From fbc0246e209dd03b9c9865904039b565d7fa9f96 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Mon, 22 Jun 2020 10:08:54 +0200 Subject: [PATCH] Cleanup. --- ddpol.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ddpol.py b/ddpol.py index 35d98a0..9d14d2a 100755 --- a/ddpol.py +++ b/ddpol.py @@ -51,7 +51,7 @@ def fit_alpha(x, y, D, a = 0, b = 1, rho = 1e-12): beta = x.new_zeros(D + 1, D + 1) beta[2:, 2:] = (q-1) * q * (r-1) * r * (b**(q+r-3) - a**(q+r-3))/(q+r-3) l, U = beta.eig(eigenvectors = True) - Q = U @ torch.diag(l[:, 0].clamp(min = 0) ** 0.5) + Q = U @ torch.diag(l[:, 0].clamp(min = 0) ** 0.5) # clamp deals with ~0 negative values B = torch.cat((B, y.new_zeros(Q.size(0))), 0) M = torch.cat((M, math.sqrt(rho) * Q.t()), 0) @@ -96,14 +96,14 @@ ax.set_ylabel('MSE', labelpad = 10) ax.axvline(x = args.nb_train_samples - 1, color = 'gray', linewidth = 0.5, linestyle = '--') + ax.text(args.nb_train_samples - 1.2, 1e-4, 'Nb. params = nb. samples', fontsize = 10, color = 'gray', rotation = 90, rotation_mode='anchor') mse_train, mse_test = compute_mse(args.nb_train_samples) - -ax.plot(torch.arange(args.D_max + 1), mse_train, color = 'blue', label = 'Train error') -ax.plot(torch.arange(args.D_max + 1), mse_test, color = 'red', label = 'Test error') +ax.plot(torch.arange(args.D_max + 1), mse_train, color = 'blue', label = 'Train') +ax.plot(torch.arange(args.D_max + 1), mse_test, color = 'red', label = 'Test') ax.legend(frameon = False) -- 2.20.1