X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=ddpol.py;h=1c5d1316e576d900f3cc1c58a16c71b2379de332;hp=6ace38d1c155164d454ca8d6d49b0f9e4475b262;hb=c75c1288649c1d4bbd941ba040367deffb5feccb;hpb=437a0746551145f241b39d4a95ae28ecd1410a54 diff --git a/ddpol.py b/ddpol.py index 6ace38d..1c5d131 100755 --- a/ddpol.py +++ b/ddpol.py @@ -132,6 +132,8 @@ for nb_train_samples in range(nb_train_samples_min, nb_train_samples_max + 1, 2) ax.plot(torch.arange(args.D_max + 1), mse_train, color = (e, e, 1.0), label = f'Train N={nb_train_samples}') ax.plot(torch.arange(args.D_max + 1), mse_test, color = (1.0, e, e), label = f'Test N={nb_train_samples}') +ax.legend(frameon = False) + fig.savefig('dd-multi-mse.pdf', bbox_inches='tight') plt.close(fig)