Update.
[pytorch.git] / ddpol.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math
9 import matplotlib.pyplot as plt
10 import torch
11
12 nb_train_samples = 8
13 D_max = 16
14 nb_runs = 250
15 train_noise_std = 0
16
17 ######################################################################
18
19 def pol_value(alpha, x):
20     x_pow = x.view(-1, 1) ** torch.arange(alpha.size(0)).view(1, -1)
21     return x_pow @ alpha
22
23 def fit_alpha(x, y, D, a = 0, b = 1, rho = 1e-12):
24     M = x.view(-1, 1) ** torch.arange(D + 1).view(1, -1)
25     B = y
26
27     if D >= 2:
28         q = torch.arange(2, D + 1, dtype = x.dtype).view(1, -1)
29         r = q.view(-1,  1)
30         beta = x.new_zeros(D + 1, D + 1)
31         beta[2:, 2:] = (q-1) * q * (r-1) * r * (b**(q+r-3) - a**(q+r-3))/(q+r-3)
32         l, U = beta.eig(eigenvectors = True)
33         Q = U @ torch.diag(l[:, 0].pow(0.5))
34         B = torch.cat((B, y.new_zeros(Q.size(0))), 0)
35         M = torch.cat((M, math.sqrt(rho) * Q.t()), 0)
36
37     return torch.lstsq(B, M).solution.view(-1)[:D+1]
38
39 ######################################################################
40
41 def phi(x):
42     # return 4 * (x - 0.6) ** 2 * (x >= 0.6) - 4 * (x - 0.4) ** 2 * (x <= 0.4) + 0.5
43     # return 4 * (x - 0.5) ** 2 * (x >= 0.5)
44     return torch.abs(torch.abs(x - 0.4) - 0.2) + x/2 - 0.1
45     # return x/2 - torch.sign(x-0.4) * 0.3
46
47 ######################################################################
48
49 torch.manual_seed(0)
50
51 mse_train = torch.zeros(nb_runs, D_max + 1)
52 mse_test = torch.zeros(nb_runs, D_max + 1)
53
54 for k in range(nb_runs):
55     x_train = torch.rand(nb_train_samples, dtype = torch.float64)
56     # x_train = torch.linspace(0, 1, nb_train_samples, dtype = torch.float64)
57     y_train = phi(x_train)
58     if train_noise_std > 0:
59         y_train = y_train + torch.empty_like(y_train).normal_(0, train_noise_std)
60     x_test = torch.linspace(0, 1, 100, dtype = x_train.dtype)
61     y_test = phi(x_test)
62
63     for D in range(D_max + 1):
64         alpha = fit_alpha(x_train, y_train, D)
65         mse_train[k, D] = ((pol_value(alpha, x_train) - y_train)**2).mean()
66         mse_test[k, D] = ((pol_value(alpha, x_test) - y_test)**2).mean()
67
68 mse_train = mse_train.median(0).values
69 mse_test = mse_test.median(0).values
70
71 ######################################################################
72 # Plot the MSE vs. degree curves
73
74 fig = plt.figure()
75
76 ax = fig.add_subplot(1, 1, 1)
77 ax.set_yscale('log')
78 ax.set_ylim(1e-5, 1)
79 ax.set_xlabel('Polynomial degree', labelpad = 10)
80 ax.set_ylabel('MSE', labelpad = 10)
81
82 ax.axvline(x = nb_train_samples - 1, color = 'gray', linewidth = 0.5)
83 ax.plot(torch.arange(D_max + 1), mse_train, color = 'blue', label = 'Train error')
84 ax.plot(torch.arange(D_max + 1), mse_test, color = 'red', label = 'Test error')
85
86 ax.legend(frameon = False)
87
88 fig.savefig('dd-mse.pdf', bbox_inches='tight')
89
90 ######################################################################
91 # Plot some examples of train / test
92
93 torch.manual_seed(5) # I picked that for pretty
94
95 x_train = torch.rand(nb_train_samples, dtype = torch.float64)
96 # x_train = torch.linspace(0, 1, nb_train_samples, dtype = torch.float64)
97 y_train = phi(x_train)
98 if train_noise_std > 0:
99     y_train = y_train + torch.empty_like(y_train).normal_(0, train_noise_std)
100 x_test = torch.linspace(0, 1, 100, dtype = x_train.dtype)
101 y_test = phi(x_test)
102
103 for D in range(D_max + 1):
104     fig = plt.figure()
105
106     ax = fig.add_subplot(1, 1, 1)
107     ax.set_title(f'Degree {D}')
108     ax.set_ylim(-0.1, 1.1)
109     ax.plot(x_test, y_test, color = 'blue', label = 'Test values')
110     ax.scatter(x_train, y_train, color = 'blue', label = 'Training examples')
111
112     alpha = fit_alpha(x_train, y_train, D)
113     ax.plot(x_test, pol_value(alpha, x_test), color = 'red', label = 'Fitted polynomial')
114
115     ax.legend(frameon = False)
116
117     fig.savefig(f'dd-example-{D:02d}.pdf', bbox_inches='tight')
118
119 ######################################################################