OCD update.
[pytorch.git] / ddpol.py
index 51e7636..35d98a0 100755 (executable)
--- a/ddpol.py
+++ b/ddpol.py
@@ -83,12 +83,6 @@ def compute_mse(nb_train_samples):
 
     return mse_train.median(0).values, mse_test.median(0).values
 
-######################################################################
-
-torch.manual_seed(0)
-
-mse_train, mse_test = compute_mse(args.nb_train_samples)
-
 ######################################################################
 # Plot the MSE vs. degree curves
 
@@ -100,7 +94,14 @@ ax.set_ylim(1e-5, 1)
 ax.set_xlabel('Polynomial degree', labelpad = 10)
 ax.set_ylabel('MSE', labelpad = 10)
 
-ax.axvline(x = args.nb_train_samples - 1, color = 'gray', linewidth = 0.5)
+ax.axvline(x = args.nb_train_samples - 1,
+           color = 'gray', linewidth = 0.5, linestyle = '--')
+ax.text(args.nb_train_samples - 1.2, 1e-4, 'Nb. params = nb. samples',
+        fontsize = 10, color = 'gray',
+        rotation = 90, rotation_mode='anchor')
+
+mse_train, mse_test = compute_mse(args.nb_train_samples)
+
 ax.plot(torch.arange(args.D_max + 1), mse_train, color = 'blue', label = 'Train error')
 ax.plot(torch.arange(args.D_max + 1), mse_test, color = 'red', label = 'Test error')