Added the generation of dd-multi-mse.pdf
authorFrancois Fleuret <francois@fleuret.org>
Tue, 30 Jun 2020 09:06:04 +0000 (11:06 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Tue, 30 Jun 2020 09:06:04 +0000 (11:06 +0200)
ddpol.py

index 9d14d2a..6ace38d 100755 (executable)
--- a/ddpol.py
+++ b/ddpol.py
@@ -111,6 +111,31 @@ fig.savefig('dd-mse.pdf', bbox_inches='tight')
 
 plt.close(fig)
 
+######################################################################
+# Plot multiple MSE vs. degree curves
+
+fig = plt.figure()
+
+ax = fig.add_subplot(1, 1, 1)
+ax.set_yscale('log')
+ax.set_ylim(1e-5, 1)
+ax.set_xlabel('Polynomial degree', labelpad = 10)
+ax.set_ylabel('MSE', labelpad = 10)
+
+nb_train_samples_min = args.nb_train_samples - 4
+nb_train_samples_max = args.nb_train_samples
+
+for nb_train_samples in range(nb_train_samples_min, nb_train_samples_max + 1, 2):
+    mse_train, mse_test = compute_mse(nb_train_samples)
+    e = float(nb_train_samples - nb_train_samples_min) / float(nb_train_samples_max - nb_train_samples_min)
+    e = 0.15 + 0.7 * e
+    ax.plot(torch.arange(args.D_max + 1), mse_train, color = (e, e, 1.0), label = f'Train N={nb_train_samples}')
+    ax.plot(torch.arange(args.D_max + 1), mse_test, color = (1.0, e, e), label = f'Test N={nb_train_samples}')
+
+fig.savefig('dd-multi-mse.pdf', bbox_inches='tight')
+
+plt.close(fig)
+
 ######################################################################
 # Plot some examples of train / test