-ax.axvline(x = args.nb_train_samples - 1, color = 'gray', linewidth = 0.5)
-ax.plot(torch.arange(args.D_max + 1), mse_train, color = 'blue', label = 'Train error')
-ax.plot(torch.arange(args.D_max + 1), mse_test, color = 'red', label = 'Test error')
+nb_train_samples_min = args.nb_train_samples - 4
+nb_train_samples_max = args.nb_train_samples
+
+for nb_train_samples in range(nb_train_samples_min, nb_train_samples_max + 1, 2):
+ mse_train, mse_test = compute_mse(nb_train_samples)
+ e = float(nb_train_samples - nb_train_samples_min) / float(nb_train_samples_max - nb_train_samples_min)
+ e = 0.15 + 0.7 * e
+ ax.plot(torch.arange(args.D_max + 1), mse_train, color = (e, e, 1.0), label = f'Train N={nb_train_samples}')
+ ax.plot(torch.arange(args.D_max + 1), mse_test, color = (1.0, e, e), label = f'Test N={nb_train_samples}')