Initial commit.
[pytorch.git] / ddpol.py
index 6ace38d..f33b0a1 100755 (executable)
--- a/ddpol.py
+++ b/ddpol.py
@@ -59,6 +59,8 @@ def fit_alpha(x, y, D, a = 0, b = 1, rho = 1e-12):
 
 ######################################################################
 
+# The "ground truth"
+
 def phi(x):
     return torch.abs(torch.abs(x - 0.4) - 0.2) + x/2 - 0.1
 
@@ -132,6 +134,8 @@ for nb_train_samples in range(nb_train_samples_min, nb_train_samples_max + 1, 2)
     ax.plot(torch.arange(args.D_max + 1), mse_train, color = (e, e, 1.0), label = f'Train N={nb_train_samples}')
     ax.plot(torch.arange(args.D_max + 1), mse_test, color = (1.0, e, e), label = f'Test N={nb_train_samples}')
 
+ax.legend(frameon = False)
+
 fig.savefig('dd-multi-mse.pdf', bbox_inches='tight')
 
 plt.close(fig)