X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=ddpol.py;h=f33b0a1e08be94a8c7e0cbd158dc36c96fa68618;hb=c16fa89db08b59e454c6ca4b5c68bf7396e876dc;hp=97d2ff5b8cfb9e0d868ef04b925b75f0d9b6b0ae;hpb=606f9f7294872c8291ea6449d3b88f474b21adc9;p=pytorch.git

diff --git a/ddpol.py b/ddpol.py
index 97d2ff5..f33b0a1 100755
--- a/ddpol.py
+++ b/ddpol.py
@@ -5,105 +5,169 @@
 
 # Written by Francois Fleuret
 
-import math
+import math, argparse
 
 import matplotlib.pyplot as plt
+
 import torch
 
 ######################################################################
 
-def compute_alpha(x, y, D, a = 0, b = 1, rho = 1e-11):
+parser = argparse.ArgumentParser(description='Example of double descent with polynomial regression.')
+
+parser.add_argument('--D-max',
+                    type = int, default = 16)
+
+parser.add_argument('--nb-runs',
+                    type = int, default = 250)
+
+parser.add_argument('--nb-train-samples',
+                    type = int, default = 8)
+
+parser.add_argument('--train-noise-std',
+                    type = float, default = 0.)
+
+parser.add_argument('--seed',
+                    type = int, default = 0,
+                    help = 'Random seed (default 0, < 0 for no seeding)')
+
+args = parser.parse_args()
+
+if args.seed >= 0:
+    torch.manual_seed(args.seed)
+
+######################################################################
+
+def pol_value(alpha, x):
+    x_pow = x.view(-1, 1) ** torch.arange(alpha.size(0)).view(1, -1)
+    return x_pow @ alpha
+
+def fit_alpha(x, y, D, a = 0, b = 1, rho = 1e-12):
     M = x.view(-1, 1) ** torch.arange(D + 1).view(1, -1)
     B = y
 
-    if D+1 > 2:
-        q = torch.arange(2, D + 1).view( 1, -1).to(x.dtype)
+    if D >= 2:
+        q = torch.arange(2, D + 1, dtype = x.dtype).view(1, -1)
         r = q.view(-1, 1)
         beta = x.new_zeros(D + 1, D + 1)
         beta[2:, 2:] = (q-1) * q * (r-1) * r * (b**(q+r-3) - a**(q+r-3))/(q+r-3)
         l, U = beta.eig(eigenvectors = True)
-        Q = U @ torch.diag(l[:, 0].pow(0.5))
+        Q = U @ torch.diag(l[:, 0].clamp(min = 0) ** 0.5) # clamp removes slightly negative eigenvalues due to rounding
         B = torch.cat((B, y.new_zeros(Q.size(0))), 0)
         M = torch.cat((M, math.sqrt(rho) * Q.t()), 0)
 
-    alpha = torch.lstsq(B, M).solution.view(-1)[:D+1]
-
-    return alpha
+    return torch.lstsq(B, M).solution[:D+1, 0]
 
 ######################################################################
 
+# The "ground truth"
+
 def phi(x):
-    return 4 * (x - 0.5) ** 2 * (x >= 0.5)
+    return torch.abs(torch.abs(x - 0.4) - 0.2) + x/2 - 0.1
 
 ######################################################################
 
-torch.manual_seed(0)
+def compute_mse(nb_train_samples):
+    mse_train = torch.zeros(args.nb_runs, args.D_max + 1)
+    mse_test = torch.zeros(args.nb_runs, args.D_max + 1)
 
-nb_train_samples = 7
-D_max = 16
-nb_runs = 250
+    for k in range(args.nb_runs):
+        x_train = torch.rand(nb_train_samples, dtype = torch.float64)
+        y_train = phi(x_train)
+        if args.train_noise_std > 0:
+            y_train = y_train + torch.empty_like(y_train).normal_(0, args.train_noise_std)
+        x_test = torch.linspace(0, 1, 100, dtype = x_train.dtype)
+        y_test = phi(x_test)
 
-mse_train = torch.zeros(nb_runs, D_max + 1)
-mse_test = torch.zeros(nb_runs, D_max + 1)
+        for D in range(args.D_max + 1):
+            alpha = fit_alpha(x_train, y_train, D)
+            mse_train[k, D] = ((pol_value(alpha, x_train) - y_train)**2).mean()
+            mse_test[k, D] = ((pol_value(alpha, x_test) - y_test)**2).mean()
 
-for k in range(nb_runs):
-    x_train = torch.rand(nb_train_samples, dtype = torch.float64)
-    y_train = phi(x_train)
-    y_train = y_train + torch.empty(y_train.size(), dtype = y_train.dtype).normal_(0, 0.1)
-    x_test = torch.linspace(0, 1, 100, dtype = x_train.dtype)
-    y_test = phi(x_test)
-
-    for D in range(D_max + 1):
-        alpha = compute_alpha(x_train, y_train, D)
-        X_train = x_train.view(-1, 1) ** torch.arange(D + 1).view(1, -1)
-        X_test = x_test.view(-1, 1) ** torch.arange(D + 1).view(1, -1)
-        mse_train[k, D] = ((X_train @ alpha - y_train)**2).mean()
-        mse_test[k, D] = ((X_test @ alpha - y_test)**2).mean()
-
-mse_train = mse_train.median(0).values
-mse_test = mse_test.median(0).values
+    return mse_train.median(0).values, mse_test.median(0).values
 
 ######################################################################
+# Plot the MSE vs. degree curves
 
-torch.manual_seed(4) # I picked that for pretty
+fig = plt.figure()
 
-x_train = torch.rand(nb_train_samples, dtype = torch.float64)
-y_train = phi(x_train)
-y_train = y_train + torch.empty(y_train.size(), dtype = y_train.dtype).normal_(0, 0.1)
-x_test = torch.linspace(0, 1, 100, dtype = x_train.dtype)
-y_test = phi(x_test)
+ax = fig.add_subplot(1, 1, 1)
+ax.set_yscale('log')
+ax.set_ylim(1e-5, 1)
+ax.set_xlabel('Polynomial degree', labelpad = 10)
+ax.set_ylabel('MSE', labelpad = 10)
 
-for D in range(D_max + 1):
-    fig = plt.figure()
+ax.axvline(x = args.nb_train_samples - 1,
+           color = 'gray', linewidth = 0.5, linestyle = '--')
 
-    ax = fig.add_subplot(1, 1, 1)
-    ax.set_title(f'Degree {D}')
-    ax.set_ylim(-0.1, 1.1)
-    ax.plot(x_test, y_test, color = 'blue', label = 'Test values')
-    ax.scatter(x_train, y_train, color = 'blue', label = 'Training examples')
+ax.text(args.nb_train_samples - 1.2, 1e-4, 'Nb. params = nb. samples',
+        fontsize = 10, color = 'gray',
+        rotation = 90, rotation_mode='anchor')
 
-    alpha = compute_alpha(x_train, y_train, D)
-    X_test = x_test.view(-1, 1) ** torch.arange(D + 1).view(1, -1)
-    ax.plot(x_test, X_test @ alpha, color = 'red', label = 'Fitted polynomial')
+mse_train, mse_test = compute_mse(args.nb_train_samples)
+ax.plot(torch.arange(args.D_max + 1), mse_train, color = 'blue', label = 'Train')
+ax.plot(torch.arange(args.D_max + 1), mse_test, color = 'red', label = 'Test')
 
-    ax.legend(frameon = False)
+ax.legend(frameon = False)
 
-    fig.savefig(f'dd-example-{D:02d}.pdf', bbox_inches='tight')
+fig.savefig('dd-mse.pdf', bbox_inches='tight')
+
+plt.close(fig)
 
 ######################################################################
+# Plot multiple MSE vs. degree curves
 
 fig = plt.figure()
 
 ax = fig.add_subplot(1, 1, 1)
 ax.set_yscale('log')
+ax.set_ylim(1e-5, 1)
 ax.set_xlabel('Polynomial degree', labelpad = 10)
 ax.set_ylabel('MSE', labelpad = 10)
 
-ax.axvline(x = nb_train_samples - 1, color = 'gray', linewidth = 0.5)
-ax.plot(torch.arange(D_max + 1), mse_train, color = 'blue', label = 'Train error')
-ax.plot(torch.arange(D_max + 1), mse_test, color = 'red', label = 'Test error')
+nb_train_samples_min = args.nb_train_samples - 4
+nb_train_samples_max = args.nb_train_samples
+
+for nb_train_samples in range(nb_train_samples_min, nb_train_samples_max + 1, 2):
+    mse_train, mse_test = compute_mse(nb_train_samples)
+    e = float(nb_train_samples - nb_train_samples_min) / float(nb_train_samples_max - nb_train_samples_min)
+    e = 0.15 + 0.7 * e
+    ax.plot(torch.arange(args.D_max + 1), mse_train, color = (e, e, 1.0), label = f'Train N={nb_train_samples}')
+    ax.plot(torch.arange(args.D_max + 1), mse_test, color = (1.0, e, e), label = f'Test N={nb_train_samples}')
 
 ax.legend(frameon = False)
 
-fig.savefig('dd-mse.pdf', bbox_inches='tight')
+fig.savefig('dd-multi-mse.pdf', bbox_inches='tight')
+
+plt.close(fig)
+
+######################################################################
+# Plot some examples of train / test
+
+torch.manual_seed(9) # this seed gives a visually clean example
+
+x_train = torch.rand(args.nb_train_samples, dtype = torch.float64)
+y_train = phi(x_train)
+if args.train_noise_std > 0:
+    y_train = y_train + torch.empty_like(y_train).normal_(0, args.train_noise_std)
+x_test = torch.linspace(0, 1, 100, dtype = x_train.dtype)
+y_test = phi(x_test)
+
+for D in range(args.D_max + 1):
+    fig = plt.figure()
+
+    ax = fig.add_subplot(1, 1, 1)
+    ax.set_title(f'Degree {D}')
+    ax.set_ylim(-0.1, 1.1)
+    ax.plot(x_test, y_test, color = 'black', label = 'Test values')
+    ax.scatter(x_train, y_train, color = 'blue', label = 'Train samples')
+
+    alpha = fit_alpha(x_train, y_train, D)
+    ax.plot(x_test, pol_value(alpha, x_test), color = 'red', label = 'Fitted polynomial')
+
+    ax.legend(frameon = False)
+
+    fig.savefig(f'dd-example-{D:02d}.pdf', bbox_inches='tight')
+
+    plt.close(fig)
 
 ######################################################################
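
Note on PyTorch compatibility: fit_alpha relies on torch.lstsq and Tensor.eig, both of which were deprecated and later removed from PyTorch. On a current release the same regularized fit can be expressed with their torch.linalg replacements. The sketch below is a hypothetical port, not part of the commit; fit_alpha_modern is a name introduced here for illustration.

    import math, torch

    def fit_alpha_modern(x, y, D, a = 0, b = 1, rho = 1e-12):
        # design matrix: one column per monomial x^0 ... x^D
        M = x.view(-1, 1) ** torch.arange(D + 1).view(1, -1)
        B = y.view(-1, 1)
        if D >= 2:
            # beta[q, r] = integral over [a, b] of (x^q)'' (x^r)'', i.e. the
            # curvature penalty rho * int p''(x)^2 dx in the monomial basis
            q = torch.arange(2, D + 1, dtype = x.dtype).view(1, -1)
            r = q.view(-1, 1)
            beta = x.new_zeros(D + 1, D + 1)
            beta[2:, 2:] = (q-1) * q * (r-1) * r * (b**(q+r-3) - a**(q+r-3))/(q+r-3)
            # beta is symmetric PSD, so eigh applies; the clamp removes
            # slightly negative eigenvalues due to rounding
            l, U = torch.linalg.eigh(beta)
            Q = U @ torch.diag(l.clamp(min = 0) ** 0.5)
            B = torch.cat((B, y.new_zeros(Q.size(0), 1)), 0)
            M = torch.cat((M, math.sqrt(rho) * Q.t()), 0)
        # torch.linalg.lstsq takes (A, B), the reverse of the old torch.lstsq(B, A)
        return torch.linalg.lstsq(M, B).solution.view(-1)

One detail worth noting: torch.linalg.eigh is the natural tool here because beta is symmetric, whereas the committed code calls the general eig and uses only the real parts l[:, 0] of the eigenvalues.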
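The closed form assigned to beta[2:, 2:] follows from integrating products of second derivatives of the basis monomials: for q, r >= 2, the integral from a to b of q(q-1)x^(q-2) * r(r-1)x^(r-2) dx equals q(q-1)r(r-1) * (b^(q+r-3) - a^(q+r-3)) / (q+r-3). A quick numerical sanity check of that formula, assuming trapezoidal quadrature is acceptable (beta_numeric is a hypothetical helper, not part of the script):

    import torch

    def beta_numeric(D, a = 0.0, b = 1.0, n = 100001):
        x = torch.linspace(a, b, n, dtype = torch.float64)
        e = torch.arange(D + 1, dtype = torch.float64)
        # second derivative of x^e is e (e-1) x^(e-2), identically zero for e < 2
        d2 = e * (e - 1) * x.view(-1, 1) ** (e - 2).clamp(min = 0)
        # trapezoidal quadrature weights on the regular grid
        w = x.new_full((n,), (b - a) / (n - 1))
        w[0] /= 2
        w[-1] /= 2
        # beta[q, r] = sum_i w_i * d2[i, q] * d2[i, r]
        return torch.einsum('i,iq,ir->qr', w, d2, d2)

Comparing beta_numeric(16) against the matrix built inside fit_alpha should show agreement up to quadrature error.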
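The dashed vertical line in dd-mse.pdf marks the interpolation threshold D = nb_train_samples - 1, where the polynomial has exactly as many coefficients as there are training points: the train MSE collapses to nearly zero there, while the test MSE peaks before descending again, which is the double descent the script demonstrates. A minimal check of that threshold, assuming pol_value, fit_alpha and phi from the script are in scope:

    import torch

    # with D + 1 == nb_train_samples the fit has as many coefficients as
    # constraints, so it interpolates the noise-free training points almost
    # exactly (the tiny rho = 1e-12 keeps it from being an exact solve)
    x = torch.rand(8, dtype = torch.float64)
    y = phi(x)                       # ground truth from the script
    alpha = fit_alpha(x, y, D = 7)   # 8 coefficients for 8 samples
    print(((pol_value(alpha, x) - y) ** 2).mean().item())  # ~0

The figures can be regenerated from the command line, e.g. ./ddpol.py --nb-train-samples 8 --train-noise-std 0.1, using the flags added by the argparse block above (noise defaults to 0).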