X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=ddpol.py;h=f33b0a1e08be94a8c7e0cbd158dc36c96fa68618;hb=c9b738deac0ba9378509e684273de90089a2f5d7;hp=6ace38d1c155164d454ca8d6d49b0f9e4475b262;hpb=437a0746551145f241b39d4a95ae28ecd1410a54;p=pytorch.git diff --git a/ddpol.py b/ddpol.py index 6ace38d..f33b0a1 100755 --- a/ddpol.py +++ b/ddpol.py @@ -59,6 +59,8 @@ def fit_alpha(x, y, D, a = 0, b = 1, rho = 1e-12): ###################################################################### +# The "ground truth" + def phi(x): return torch.abs(torch.abs(x - 0.4) - 0.2) + x/2 - 0.1 @@ -132,6 +134,8 @@ for nb_train_samples in range(nb_train_samples_min, nb_train_samples_max + 1, 2) ax.plot(torch.arange(args.D_max + 1), mse_train, color = (e, e, 1.0), label = f'Train N={nb_train_samples}') ax.plot(torch.arange(args.D_max + 1), mse_test, color = (1.0, e, e), label = f'Test N={nb_train_samples}') +ax.legend(frameon = False) + fig.savefig('dd-multi-mse.pdf', bbox_inches='tight') plt.close(fig)