Update.
[pytorch.git] / ddpol.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, argparse
9 import matplotlib.pyplot as plt
10
11 import torch
12
13 ######################################################################
14
15 parser = argparse.ArgumentParser(
16     description="Example of double descent with polynomial regression."
17 )
18
19 parser.add_argument("--D-max", type=int, default=16)
20
21 parser.add_argument("--nb-runs", type=int, default=250)
22
23 parser.add_argument("--nb-train-samples", type=int, default=8)
24
25 parser.add_argument("--train-noise-std", type=float, default=0.0)
26
27 parser.add_argument(
28     "--seed", type=int, default=0, help="Random seed (default 0, < 0 is no seeding)"
29 )
30
31 args = parser.parse_args()
32
33 if args.seed >= 0:
34     torch.manual_seed(args.seed)
35
36 ######################################################################
37
38
39 def pol_value(alpha, x):
40     x_pow = x.view(-1, 1) ** torch.arange(alpha.size(0)).view(1, -1)
41     return x_pow @ alpha
42
43
44 def fit_alpha(x, y, D, a=0, b=1, rho=1e-12):
45     M = x.view(-1, 1) ** torch.arange(D + 1).view(1, -1)
46     B = y
47
48     if D >= 2:
49         q = torch.arange(2, D + 1, dtype=x.dtype).view(1, -1)
50         r = q.view(-1, 1)
51         beta = x.new_zeros(D + 1, D + 1)
52         beta[2:, 2:] = (
53             (q - 1)
54             * q
55             * (r - 1)
56             * r
57             * (b ** (q + r - 3) - a ** (q + r - 3))
58             / (q + r - 3)
59         )
60         W = torch.linalg.eig(beta)
61         l, U = W.eigenvalues.real, W.eigenvectors.real
62         Q = U @ torch.diag(l.clamp(min=0) ** 0.5)  # clamp deals with ~0 negative values
63         B = torch.cat((B, y.new_zeros(Q.size(0))), 0)
64         M = torch.cat((M, math.sqrt(rho) * Q.t()), 0)
65
66     return torch.linalg.lstsq(M, B).solution[: D + 1]
67
68
69 ######################################################################
70
71 # The "ground truth"
72
73
74 def phi(x):
75     return torch.abs(torch.abs(x - 0.4) - 0.2) + x / 2 - 0.1
76
77
78 ######################################################################
79
80
81 def compute_mse(nb_train_samples):
82     mse_train = torch.zeros(args.nb_runs, args.D_max + 1)
83     mse_test = torch.zeros(args.nb_runs, args.D_max + 1)
84
85     for k in range(args.nb_runs):
86         x_train = torch.rand(nb_train_samples, dtype=torch.float64)
87         y_train = phi(x_train)
88         if args.train_noise_std > 0:
89             y_train = y_train + torch.empty_like(y_train).normal_(
90                 0, args.train_noise_std
91             )
92         x_test = torch.linspace(0, 1, 100, dtype=x_train.dtype)
93         y_test = phi(x_test)
94
95         for D in range(args.D_max + 1):
96             alpha = fit_alpha(x_train, y_train, D)
97             mse_train[k, D] = ((pol_value(alpha, x_train) - y_train) ** 2).mean()
98             mse_test[k, D] = ((pol_value(alpha, x_test) - y_test) ** 2).mean()
99
100     return mse_train.median(0).values, mse_test.median(0).values
101
102
103 ######################################################################
104 # Plot the MSE vs. degree curves
105
106 fig = plt.figure()
107
108 ax = fig.add_subplot(1, 1, 1)
109 ax.set_yscale("log")
110 ax.set_ylim(1e-5, 1)
111 ax.set_xlabel("Polynomial degree", labelpad=10)
112 ax.set_ylabel("MSE", labelpad=10)
113
114 ax.axvline(x=args.nb_train_samples - 1, color="gray", linewidth=0.5, linestyle="--")
115
116 ax.text(
117     args.nb_train_samples - 1.2,
118     1e-4,
119     "nb. params = nb. samples",
120     fontsize=10,
121     color="gray",
122     rotation=90,
123     rotation_mode="anchor",
124 )
125
126 mse_train, mse_test = compute_mse(args.nb_train_samples)
127 ax.plot(torch.arange(args.D_max + 1), mse_train, color="blue", label="Train")
128 ax.plot(torch.arange(args.D_max + 1), mse_test, color="red", label="Test")
129
130 ax.legend(frameon=False)
131
132 fig.savefig("dd-mse.pdf", bbox_inches="tight")
133
134 plt.close(fig)
135
136 ######################################################################
137 # Plot multiple MSE vs. degree curves
138
139 fig = plt.figure()
140
141 ax = fig.add_subplot(1, 1, 1)
142 ax.set_yscale("log")
143 ax.set_ylim(1e-5, 1)
144 ax.set_xlabel("Polynomial degree", labelpad=10)
145 ax.set_ylabel("MSE", labelpad=10)
146
147 nb_train_samples_min = args.nb_train_samples - 4
148 nb_train_samples_max = args.nb_train_samples
149
150 for nb_train_samples in range(nb_train_samples_min, nb_train_samples_max + 1, 2):
151     mse_train, mse_test = compute_mse(nb_train_samples)
152     e = float(nb_train_samples - nb_train_samples_min) / float(
153         nb_train_samples_max - nb_train_samples_min
154     )
155     e = 0.15 + 0.7 * e
156     ax.plot(
157         torch.arange(args.D_max + 1),
158         mse_train,
159         color=(e, e, 1.0),
160         label=f"Train N={nb_train_samples}",
161     )
162     ax.plot(
163         torch.arange(args.D_max + 1),
164         mse_test,
165         color=(1.0, e, e),
166         label=f"Test N={nb_train_samples}",
167     )
168
169 ax.legend(frameon=False)
170
171 fig.savefig("dd-multi-mse.pdf", bbox_inches="tight")
172
173 plt.close(fig)
174
175 ######################################################################
176 # Plot some examples of train / test
177
178 torch.manual_seed(9)  # I picked that for pretty
179
180 x_train = torch.rand(args.nb_train_samples, dtype=torch.float64)
181 y_train = phi(x_train)
182 if args.train_noise_std > 0:
183     y_train = y_train + torch.empty_like(y_train).normal_(0, args.train_noise_std)
184 x_test = torch.linspace(0, 1, 100, dtype=x_train.dtype)
185 y_test = phi(x_test)
186
187 for D in range(args.D_max + 1):
188     fig = plt.figure()
189
190     ax = fig.add_subplot(1, 1, 1)
191     ax.set_title(f"Degree {D}")
192     ax.set_ylim(-0.1, 1.1)
193     ax.plot(x_test, y_test, color="black", label="Test values")
194     ax.scatter(x_train, y_train, color="blue", label="Train samples")
195
196     alpha = fit_alpha(x_train, y_train, D)
197     ax.plot(x_test, pol_value(alpha, x_test), color="red", label="Fitted polynomial")
198
199     ax.legend(frameon=False)
200
201     fig.savefig(f"dd-example-{D:02d}.pdf", bbox_inches="tight")
202
203     plt.close(fig)
204
205 ######################################################################