Update.
[pytorch.git] / ddpol.py
index 35d98a0..1975ab2 100755 (executable)
--- a/ddpol.py
+++ b/ddpol.py
@@ -12,23 +12,21 @@ import torch
 
 ######################################################################
 
-parser = argparse.ArgumentParser(description='Example of double descent with polynomial regression.')
+parser = argparse.ArgumentParser(
+    description="Example of double descent with polynomial regression."
+)
 
-parser.add_argument('--D-max',
-                    type = int, default = 16)
+parser.add_argument("--D-max", type=int, default=16)
 
-parser.add_argument('--nb-runs',
-                    type = int, default = 250)
+parser.add_argument("--nb-runs", type=int, default=250)
 
-parser.add_argument('--nb-train-samples',
-                    type = int, default = 8)
+parser.add_argument("--nb-train-samples", type=int, default=8)
 
-parser.add_argument('--train-noise-std',
-                    type = float, default = 0.)
+parser.add_argument("--train-noise-std", type=float, default=0.0)
 
-parser.add_argument('--seed',
-                    type = int, default = 0,
-                    help = 'Random seed (default 0, < 0 is no seeding)')
+parser.add_argument(
+    "--seed", type=int, default=0, help="Random seed (default 0, < 0 is no seeding)"
+)
 
 args = parser.parse_args()
 
@@ -37,107 +35,170 @@ if args.seed >= 0:
 
 ######################################################################
 
+
 def pol_value(alpha, x):
     x_pow = x.view(-1, 1) ** torch.arange(alpha.size(0)).view(1, -1)
     return x_pow @ alpha
 
-def fit_alpha(x, y, D, a = 0, b = 1, rho = 1e-12):
+
+def fit_alpha(x, y, D, a=0, b=1, rho=1e-12):
     M = x.view(-1, 1) ** torch.arange(D + 1).view(1, -1)
     B = y
 
     if D >= 2:
-        q = torch.arange(2, D + 1, dtype = x.dtype).view(1, -1)
-        r = q.view(-1,  1)
+        q = torch.arange(2, D + 1, dtype=x.dtype).view(1, -1)
+        r = q.view(-1, 1)
         beta = x.new_zeros(D + 1, D + 1)
-        beta[2:, 2:] = (q-1) * q * (r-1) * r * (b**(q+r-3) - a**(q+r-3))/(q+r-3)
-        l, U = beta.eig(eigenvectors = True)
-        Q = U @ torch.diag(l[:, 0].clamp(min = 0) ** 0.5)
+        beta[2:, 2:] = (
+            (q - 1)
+            * q
+            * (r - 1)
+            * r
+            * (b ** (q + r - 3) - a ** (q + r - 3))
+            / (q + r - 3)
+        )
+        W = torch.linalg.eig(beta)
+        l, U = W.eigenvalues.real, W.eigenvectors.real
+        Q = U @ torch.diag(l.clamp(min=0) ** 0.5)  # clamp deals with ~0 negative values
         B = torch.cat((B, y.new_zeros(Q.size(0))), 0)
         M = torch.cat((M, math.sqrt(rho) * Q.t()), 0)
 
-    return torch.lstsq(B, M).solution[:D+1, 0]
+    return torch.linalg.lstsq(M, B).solution[: D + 1]
+
 
 ######################################################################
 
+# The "ground truth"
+
+
 def phi(x):
-    return torch.abs(torch.abs(x - 0.4) - 0.2) + x/2 - 0.1
+    return torch.abs(torch.abs(x - 0.4) - 0.2) + x / 2 - 0.1
+
 
 ######################################################################
 
+
 def compute_mse(nb_train_samples):
     mse_train = torch.zeros(args.nb_runs, args.D_max + 1)
     mse_test = torch.zeros(args.nb_runs, args.D_max + 1)
 
     for k in range(args.nb_runs):
-        x_train = torch.rand(nb_train_samples, dtype = torch.float64)
+        x_train = torch.rand(nb_train_samples, dtype=torch.float64)
         y_train = phi(x_train)
         if args.train_noise_std > 0:
-            y_train = y_train + torch.empty_like(y_train).normal_(0, args.train_noise_std)
-        x_test = torch.linspace(0, 1, 100, dtype = x_train.dtype)
+            y_train = y_train + torch.empty_like(y_train).normal_(
+                0, args.train_noise_std
+            )
+        x_test = torch.linspace(0, 1, 100, dtype=x_train.dtype)
         y_test = phi(x_test)
 
         for D in range(args.D_max + 1):
             alpha = fit_alpha(x_train, y_train, D)
-            mse_train[k, D] = ((pol_value(alpha, x_train) - y_train)**2).mean()
-            mse_test[k, D] = ((pol_value(alpha, x_test) - y_test)**2).mean()
+            mse_train[k, D] = ((pol_value(alpha, x_train) - y_train) ** 2).mean()
+            mse_test[k, D] = ((pol_value(alpha, x_test) - y_test) ** 2).mean()
 
     return mse_train.median(0).values, mse_test.median(0).values
 
+
 ######################################################################
 # Plot the MSE vs. degree curves
 
 fig = plt.figure()
 
 ax = fig.add_subplot(1, 1, 1)
-ax.set_yscale('log')
+ax.set_yscale("log")
 ax.set_ylim(1e-5, 1)
-ax.set_xlabel('Polynomial degree', labelpad = 10)
-ax.set_ylabel('MSE', labelpad = 10)
+ax.set_xlabel("Polynomial degree", labelpad=10)
+ax.set_ylabel("MSE", labelpad=10)
+
+ax.axvline(x=args.nb_train_samples - 1, color="gray", linewidth=0.5, linestyle="--")
 
-ax.axvline(x = args.nb_train_samples - 1,
-           color = 'gray', linewidth = 0.5, linestyle = '--')
-ax.text(args.nb_train_samples - 1.2, 1e-4, 'Nb. params = nb. samples',
-        fontsize = 10, color = 'gray',
-        rotation = 90, rotation_mode='anchor')
+ax.text(
+    args.nb_train_samples - 1.2,
+    1e-4,
+    "nb. params = nb. samples",
+    fontsize=10,
+    color="gray",
+    rotation=90,
+    rotation_mode="anchor",
+)
 
 mse_train, mse_test = compute_mse(args.nb_train_samples)
+ax.plot(torch.arange(args.D_max + 1), mse_train, color="blue", label="Train")
+ax.plot(torch.arange(args.D_max + 1), mse_test, color="red", label="Test")
 
-ax.plot(torch.arange(args.D_max + 1), mse_train, color = 'blue', label = 'Train error')
-ax.plot(torch.arange(args.D_max + 1), mse_test, color = 'red', label = 'Test error')
+ax.legend(frameon=False)
 
-ax.legend(frameon = False)
+fig.savefig("dd-mse.pdf", bbox_inches="tight")
 
-fig.savefig('dd-mse.pdf', bbox_inches='tight')
+plt.close(fig)
+
+######################################################################
+# Plot multiple MSE vs. degree curves
+
+fig = plt.figure()
+
+ax = fig.add_subplot(1, 1, 1)
+ax.set_yscale("log")
+ax.set_ylim(1e-5, 1)
+ax.set_xlabel("Polynomial degree", labelpad=10)
+ax.set_ylabel("MSE", labelpad=10)
+
+nb_train_samples_min = args.nb_train_samples - 4
+nb_train_samples_max = args.nb_train_samples
+
+for nb_train_samples in range(nb_train_samples_min, nb_train_samples_max + 1, 2):
+    mse_train, mse_test = compute_mse(nb_train_samples)
+    e = float(nb_train_samples - nb_train_samples_min) / float(
+        nb_train_samples_max - nb_train_samples_min
+    )
+    e = 0.15 + 0.7 * e
+    ax.plot(
+        torch.arange(args.D_max + 1),
+        mse_train,
+        color=(e, e, 1.0),
+        label=f"Train N={nb_train_samples}",
+    )
+    ax.plot(
+        torch.arange(args.D_max + 1),
+        mse_test,
+        color=(1.0, e, e),
+        label=f"Test N={nb_train_samples}",
+    )
+
+ax.legend(frameon=False)
+
+fig.savefig("dd-multi-mse.pdf", bbox_inches="tight")
 
 plt.close(fig)
 
 ######################################################################
 # Plot some examples of train / test
 
-torch.manual_seed(9) # I picked that for pretty
+torch.manual_seed(9)  # I picked that for pretty
 
-x_train = torch.rand(args.nb_train_samples, dtype = torch.float64)
+x_train = torch.rand(args.nb_train_samples, dtype=torch.float64)
 y_train = phi(x_train)
 if args.train_noise_std > 0:
     y_train = y_train + torch.empty_like(y_train).normal_(0, args.train_noise_std)
-x_test = torch.linspace(0, 1, 100, dtype = x_train.dtype)
+x_test = torch.linspace(0, 1, 100, dtype=x_train.dtype)
 y_test = phi(x_test)
 
 for D in range(args.D_max + 1):
     fig = plt.figure()
 
     ax = fig.add_subplot(1, 1, 1)
-    ax.set_title(f'Degree {D}')
+    ax.set_title(f"Degree {D}")
     ax.set_ylim(-0.1, 1.1)
-    ax.plot(x_test, y_test, color = 'black', label = 'Test values')
-    ax.scatter(x_train, y_train, color = 'blue', label = 'Train samples')
+    ax.plot(x_test, y_test, color="black", label="Test values")
+    ax.scatter(x_train, y_train, color="blue", label="Train samples")
 
     alpha = fit_alpha(x_train, y_train, D)
-    ax.plot(x_test, pol_value(alpha, x_test), color = 'red', label = 'Fitted polynomial')
+    ax.plot(x_test, pol_value(alpha, x_test), color="red", label="Fitted polynomial")
 
-    ax.legend(frameon = False)
+    ax.legend(frameon=False)
 
-    fig.savefig(f'dd-example-{D:02d}.pdf', bbox_inches='tight')
+    fig.savefig(f"dd-example-{D:02d}.pdf", bbox_inches="tight")
 
     plt.close(fig)