Update.
[pytorch.git] / ddpol.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, argparse
9 import matplotlib.pyplot as plt
10
11 import torch
12
13 ######################################################################
14
15 parser = argparse.ArgumentParser(description='Example of double descent with polynomial regression.')
16
17 parser.add_argument('--D-max',
18                     type = int, default = 16)
19
20 parser.add_argument('--nb-runs',
21                     type = int, default = 250)
22
23 parser.add_argument('--nb-train-samples',
24                     type = int, default = 8)
25
26 parser.add_argument('--train-noise-std',
27                     type = float, default = 0.)
28
29 parser.add_argument('--seed',
30                     type = int, default = 0,
31                     help = 'Random seed (default 0, < 0 is no seeding)')
32
33 args = parser.parse_args()
34
35 if args.seed >= 0:
36     torch.manual_seed(args.seed)
37
38 ######################################################################
39
40 def pol_value(alpha, x):
41     x_pow = x.view(-1, 1) ** torch.arange(alpha.size(0)).view(1, -1)
42     return x_pow @ alpha
43
44 def fit_alpha(x, y, D, a = 0, b = 1, rho = 1e-12):
45     M = x.view(-1, 1) ** torch.arange(D + 1).view(1, -1)
46     B = y
47
48     if D >= 2:
49         q = torch.arange(2, D + 1, dtype = x.dtype).view(1, -1)
50         r = q.view(-1,  1)
51         beta = x.new_zeros(D + 1, D + 1)
52         beta[2:, 2:] = (q-1) * q * (r-1) * r * (b**(q+r-3) - a**(q+r-3))/(q+r-3)
53         W = torch.linalg.eig(beta)
54         l, U = W.eigenvalues.real, W.eigenvectors.real
55         Q = U @ torch.diag(l.clamp(min = 0) ** 0.5) # clamp deals with ~0 negative values
56         B = torch.cat((B, y.new_zeros(Q.size(0))), 0)
57         M = torch.cat((M, math.sqrt(rho) * Q.t()), 0)
58
59     return torch.linalg.lstsq(M, B).solution[:D+1]
60
61 ######################################################################
62
63 # The "ground truth"
64
65 def phi(x):
66     return torch.abs(torch.abs(x - 0.4) - 0.2) + x/2 - 0.1
67
68 ######################################################################
69
70 def compute_mse(nb_train_samples):
71     mse_train = torch.zeros(args.nb_runs, args.D_max + 1)
72     mse_test = torch.zeros(args.nb_runs, args.D_max + 1)
73
74     for k in range(args.nb_runs):
75         x_train = torch.rand(nb_train_samples, dtype = torch.float64)
76         y_train = phi(x_train)
77         if args.train_noise_std > 0:
78             y_train = y_train + torch.empty_like(y_train).normal_(0, args.train_noise_std)
79         x_test = torch.linspace(0, 1, 100, dtype = x_train.dtype)
80         y_test = phi(x_test)
81
82         for D in range(args.D_max + 1):
83             alpha = fit_alpha(x_train, y_train, D)
84             mse_train[k, D] = ((pol_value(alpha, x_train) - y_train)**2).mean()
85             mse_test[k, D] = ((pol_value(alpha, x_test) - y_test)**2).mean()
86
87     return mse_train.median(0).values, mse_test.median(0).values
88
89 ######################################################################
90 # Plot the MSE vs. degree curves
91
92 fig = plt.figure()
93
94 ax = fig.add_subplot(1, 1, 1)
95 ax.set_yscale('log')
96 ax.set_ylim(1e-5, 1)
97 ax.set_xlabel('Polynomial degree', labelpad = 10)
98 ax.set_ylabel('MSE', labelpad = 10)
99
100 ax.axvline(x = args.nb_train_samples - 1,
101            color = 'gray', linewidth = 0.5, linestyle = '--')
102
103 ax.text(args.nb_train_samples - 1.2, 1e-4, 'nb. params = nb. samples',
104         fontsize = 10, color = 'gray',
105         rotation = 90, rotation_mode='anchor')
106
107 mse_train, mse_test = compute_mse(args.nb_train_samples)
108 ax.plot(torch.arange(args.D_max + 1), mse_train, color = 'blue', label = 'Train')
109 ax.plot(torch.arange(args.D_max + 1), mse_test, color = 'red', label = 'Test')
110
111 ax.legend(frameon = False)
112
113 fig.savefig('dd-mse.pdf', bbox_inches='tight')
114
115 plt.close(fig)
116
117 ######################################################################
118 # Plot multiple MSE vs. degree curves
119
120 fig = plt.figure()
121
122 ax = fig.add_subplot(1, 1, 1)
123 ax.set_yscale('log')
124 ax.set_ylim(1e-5, 1)
125 ax.set_xlabel('Polynomial degree', labelpad = 10)
126 ax.set_ylabel('MSE', labelpad = 10)
127
128 nb_train_samples_min = args.nb_train_samples - 4
129 nb_train_samples_max = args.nb_train_samples
130
131 for nb_train_samples in range(nb_train_samples_min, nb_train_samples_max + 1, 2):
132     mse_train, mse_test = compute_mse(nb_train_samples)
133     e = float(nb_train_samples - nb_train_samples_min) / float(nb_train_samples_max - nb_train_samples_min)
134     e = 0.15 + 0.7 * e
135     ax.plot(torch.arange(args.D_max + 1), mse_train, color = (e, e, 1.0), label = f'Train N={nb_train_samples}')
136     ax.plot(torch.arange(args.D_max + 1), mse_test, color = (1.0, e, e), label = f'Test N={nb_train_samples}')
137
138 ax.legend(frameon = False)
139
140 fig.savefig('dd-multi-mse.pdf', bbox_inches='tight')
141
142 plt.close(fig)
143
144 ######################################################################
145 # Plot some examples of train / test
146
147 torch.manual_seed(9) # I picked that for pretty
148
149 x_train = torch.rand(args.nb_train_samples, dtype = torch.float64)
150 y_train = phi(x_train)
151 if args.train_noise_std > 0:
152     y_train = y_train + torch.empty_like(y_train).normal_(0, args.train_noise_std)
153 x_test = torch.linspace(0, 1, 100, dtype = x_train.dtype)
154 y_test = phi(x_test)
155
156 for D in range(args.D_max + 1):
157     fig = plt.figure()
158
159     ax = fig.add_subplot(1, 1, 1)
160     ax.set_title(f'Degree {D}')
161     ax.set_ylim(-0.1, 1.1)
162     ax.plot(x_test, y_test, color = 'black', label = 'Test values')
163     ax.scatter(x_train, y_train, color = 'blue', label = 'Train samples')
164
165     alpha = fit_alpha(x_train, y_train, D)
166     ax.plot(x_test, pol_value(alpha, x_test), color = 'red', label = 'Fitted polynomial')
167
168     ax.legend(frameon = False)
169
170     fig.savefig(f'dd-example-{D:02d}.pdf', bbox_inches='tight')
171
172     plt.close(fig)
173
174 ######################################################################