3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
9 import matplotlib.pyplot as plt
13 ######################################################################
# Command-line options of the experiment.
parser = argparse.ArgumentParser(
    description="Example of double descent with polynomial regression."
)

parser.add_argument(
    "--D-max", type=int, default=16, help="Largest polynomial degree fitted"
)

parser.add_argument(
    "--nb-runs",
    type=int,
    default=250,
    help="Number of random training sets the MSE median is taken over",
)

parser.add_argument(
    "--nb-train-samples", type=int, default=8, help="Number of training samples"
)

parser.add_argument(
    "--train-noise-std",
    type=float,
    default=0.0,
    help="Std of the Gaussian noise added to the training targets (0 for none)",
)

parser.add_argument(
    "--seed", type=int, default=0, help="Random seed (default 0, < 0 is no seeding)"
)

args = parser.parse_args()

# A negative --seed means "do not seed", as its help text promises. Calling
# torch.manual_seed unconditionally would break that contract: torch accepts
# negative seeds and remaps them to a fixed positive one, so runs would still
# be deterministic.
if args.seed >= 0:
    torch.manual_seed(args.seed)
36 ######################################################################
def pol_value(alpha, x):
    """Evaluate the polynomial with coefficients ``alpha`` at the points ``x``.

    Args:
        alpha: 1-D tensor of K coefficients; ``alpha[k]`` multiplies ``x**k``.
        x: tensor of evaluation points (flattened before use).

    Returns:
        1-D tensor with one value per element of ``x``:
        ``sum_k alpha[k] * x**k``.
    """
    exponents = torch.arange(alpha.size(0), device=x.device)
    # Column k holds x**k, so a matrix-vector product sums alpha[k] * x**k.
    x_pow = x.view(-1, 1) ** exponents.view(1, -1)
    return x_pow @ alpha
def fit_alpha(x, y, D, a=0, b=1, rho=1e-12):
    """Fit the D + 1 coefficients of a degree-``D`` polynomial to ``(x, y)``.

    Solves the regularized least-squares problem

        min_alpha  sum_n (P(x_n) - y_n)^2  +  rho * int_a^b P''(t)^2 dt,

    with P(t) = sum_k alpha[k] t^k. The curvature penalty is the quadratic
    form alpha^T beta alpha, where for q, r >= 2

        beta[q, r] = q (q-1) r (r-1) (b^(q+r-3) - a^(q+r-3)) / (q+r-3)

    and beta is zero elsewhere. It is folded into an ordinary least-squares
    problem by factoring beta = Q Q^T and appending the rows
    sqrt(rho) * Q^T (with zero targets). With the default tiny ``rho``,
    over-parameterized fits (D + 1 > number of samples) pick the
    (near-)interpolating polynomial with the smallest integrated squared
    second derivative.

    Args:
        x: 1-D tensor of sample positions.
        y: 1-D tensor of targets, same length as ``x``.
        D: polynomial degree.
        a, b: interval over which the curvature penalty is integrated.
        rho: weight of the curvature penalty.

    Returns:
        1-D tensor of D + 1 coefficients; ``alpha[k]`` multiplies ``x**k``.
    """
    M = x.view(-1, 1) ** torch.arange(D + 1, device=x.device).view(1, -1)
    B = y

    # Coefficients of degree 0 and 1 have zero second derivative, so the
    # penalty is only non-trivial from degree 2 on.
    if D >= 2:
        q = torch.arange(2, D + 1, dtype=x.dtype, device=x.device).view(1, -1)
        r = q.view(-1, 1)
        beta = x.new_zeros(D + 1, D + 1)
        # q + r - 3 >= 1 here, so the division is always well defined.
        beta[2:, 2:] = (
            (q - 1)
            * q
            * (r - 1)
            * r
            * (b ** (q + r - 3) - a ** (q + r - 3))
            / (q + r - 3)
        )
        # beta is real symmetric PSD: eigh yields real eigenvalues and
        # orthonormal eigenvectors, which is what makes Q @ Q.T == beta.
        eigvals, eigvecs = torch.linalg.eigh(beta)
        # clamp deals with ~0 negative eigenvalues from round-off
        Q = eigvecs @ torch.diag(eigvals.clamp(min=0) ** 0.5)
        B = torch.cat((B, y.new_zeros(Q.size(0))), 0)
        M = torch.cat((M, math.sqrt(rho) * Q.t()), 0)

    return torch.linalg.lstsq(M, B).solution[: D + 1]
69 ######################################################################
def phi(x):
    """Ground-truth function that the training and test targets are drawn from.

    Piecewise linear, with kinks at x = 0.2, 0.4 and 0.6. It is continuous
    but not differentiable there, so no polynomial reproduces it exactly.
    Applied element-wise to the tensor ``x``.
    """
    return torch.abs(torch.abs(x - 0.4) - 0.2) + x / 2 - 0.1
78 ######################################################################
def compute_mse(nb_train_samples):
    """Median train and test MSE of the polynomial fits, for every degree.

    Reads ``args.nb_runs``, ``args.D_max`` and ``args.train_noise_std`` from
    the module-level ``args``. Each run does the following:

    - Draw ``nb_train_samples`` points uniformly in [0, 1) and label them
      with ``phi``.
    - If ``args.train_noise_std > 0``, add Gaussian noise of that std to
      the labels.
    - Fit a polynomial of every degree 0..``args.D_max`` with ``fit_alpha``.
    - Measure its MSE on the training points and on 100 evenly spaced test
      points in [0, 1].

    Returns:
        ``(mse_train, mse_test)``: two 1-D tensors of length
        ``args.D_max + 1``, the median over runs of the per-degree MSE.
    """
    # The test grid and its targets are identical for every run; build them
    # once. Neither call consumes random numbers, so the RNG stream is
    # unchanged.
    x_test = torch.linspace(0, 1, 100, dtype=torch.float64)
    y_test = phi(x_test)

    mse_train = torch.zeros(args.nb_runs, args.D_max + 1)
    mse_test = torch.zeros(args.nb_runs, args.D_max + 1)

    for k in range(args.nb_runs):
        x_train = torch.rand(nb_train_samples, dtype=torch.float64)
        y_train = phi(x_train)
        if args.train_noise_std > 0:
            y_train = y_train + torch.empty_like(y_train).normal_(
                0, args.train_noise_std
            )

        for D in range(args.D_max + 1):
            alpha = fit_alpha(x_train, y_train, D)
            mse_train[k, D] = ((pol_value(alpha, x_train) - y_train) ** 2).mean()
            mse_test[k, D] = ((pol_value(alpha, x_test) - y_test) ** 2).mean()

    # Median rather than mean over runs: a few badly conditioned fits cannot
    # dominate the curve.
    return mse_train.median(0).values, mse_test.median(0).values
103 ######################################################################
104 # Plot the MSE vs. degree curves
# Figure: median train / test MSE against polynomial degree, for the
# training-set size given by --nb-train-samples, saved to dd-mse.pdf.
# NOTE(review): no visible line creates `fig` for this figure (presumably a
# `fig = plt.figure()` is expected here) — confirm against the full file.
108 ax = fig.add_subplot(1, 1, 1)
111 ax.set_xlabel("Polynomial degree", labelpad=10)
112 ax.set_ylabel("MSE", labelpad=10)
# Dashed marker at degree N - 1: a degree-D fit has D + 1 coefficients, so
# there the number of parameters equals the number of training samples.
114 ax.axvline(x=args.nb_train_samples - 1, color="gray", linewidth=0.5, linestyle="--")
# NOTE(review): the next three lines look like arguments of a call
# (presumably `ax.text(...)` labelling the dashed line, placed just left of
# it at x = N - 1.2). The call's opening line, its other arguments and the
# closing parenthesis are not visible here — confirm.
117 args.nb_train_samples - 1.2,
119 "nb. params = nb. samples",
123 rotation_mode="anchor",
# One median MSE value per degree 0..args.D_max, over args.nb_runs runs.
126 mse_train, mse_test = compute_mse(args.nb_train_samples)
127 ax.plot(torch.arange(args.D_max + 1), mse_train, color="blue", label="Train")
128 ax.plot(torch.arange(args.D_max + 1), mse_test, color="red", label="Test")
130 ax.legend(frameon=False)
132 fig.savefig("dd-mse.pdf", bbox_inches="tight")
# NOTE(review): the figure is not closed in the visible lines; consider
# plt.close(fig) after saving if nothing else releases it.
136 ######################################################################
137 # Plot multiple MSE vs. degree curves
# Figure: train / test MSE curves against polynomial degree for several
# training-set sizes N (nb_train_samples - 4, - 2 and nb_train_samples),
# overlaid on one plot and saved to dd-multi-mse.pdf.
# NOTE(review): no visible line creates `fig` for this figure — confirm it
# is created before add_subplot is called.
141 ax = fig.add_subplot(1, 1, 1)
144 ax.set_xlabel("Polynomial degree", labelpad=10)
145 ax.set_ylabel("MSE", labelpad=10)
# The smallest size is always 4 below the largest, so the normalisation of
# `e` below never divides by zero.
# NOTE(review): with --nb-train-samples <= 4 the smallest N is <= 0 and is
# passed to compute_mse as-is — confirm that is never expected.
147 nb_train_samples_min = args.nb_train_samples - 4
148 nb_train_samples_max = args.nb_train_samples
# Step 2: N takes the values min, min + 2 and max.
150 for nb_train_samples in range(nb_train_samples_min, nb_train_samples_max + 1, 2):
151 mse_train, mse_test = compute_mse(nb_train_samples)
# e runs from 0.0 (smallest N) to 1.0 (largest N); presumably used to grade
# the curve colours — the lines that consume it are not visible here.
152 e = float(nb_train_samples - nb_train_samples_min) / float(
153 nb_train_samples_max - nb_train_samples_min
# NOTE(review): the closing parenthesis of the expression above, and the
# opening lines (presumably `ax.plot(`), y data and colour arguments of the
# two plot calls below, are not visible here — confirm.
157 torch.arange(args.D_max + 1),
160 label=f"Train N={nb_train_samples}",
163 torch.arange(args.D_max + 1),
166 label=f"Test N={nb_train_samples}",
169 ax.legend(frameon=False)
171 fig.savefig("dd-multi-mse.pdf", bbox_inches="tight")
175 ######################################################################
176 # Plot some examples of train / test
# One figure per degree D = 0..args.D_max, each showing the following on a
# single training set and saved to dd-example-DD.pdf:
# - the ground truth on the test grid,
# - the training samples,
# - the fitted polynomial.

torch.manual_seed(9)  # seed picked by hand because it gives pretty examples

x_train = torch.rand(args.nb_train_samples, dtype=torch.float64)
y_train = phi(x_train)
if args.train_noise_std > 0:
    y_train = y_train + torch.empty_like(y_train).normal_(0, args.train_noise_std)
x_test = torch.linspace(0, 1, 100, dtype=x_train.dtype)
y_test = phi(x_test)

for D in range(args.D_max + 1):
    # A fresh figure per degree, so each PDF holds a single plot.
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.set_title(f"Degree {D}")
    ax.set_ylim(-0.1, 1.1)
    ax.plot(x_test, y_test, color="black", label="Test values")
    ax.scatter(x_train, y_train, color="blue", label="Train samples")

    alpha = fit_alpha(x_train, y_train, D)
    ax.plot(x_test, pol_value(alpha, x_test), color="red", label="Fitted polynomial")

    ax.legend(frameon=False)

    fig.savefig(f"dd-example-{D:02d}.pdf", bbox_inches="tight")
    # Release the figure: pyplot keeps every figure alive until closed, and
    # one figure is made per degree.
    plt.close(fig)
205 ######################################################################