diff --git a/ddpol.py b/ddpol.py
index 51e7636..1975ab2 100755
--- a/ddpol.py
+++ b/ddpol.py
@@ -12,23 +12,21 @@ import torch
 
 ######################################################################
 
-parser = argparse.ArgumentParser(description='Example of double descent with polynomial regression.')
+parser = argparse.ArgumentParser(
+    description="Example of double descent with polynomial regression."
+)
 
-parser.add_argument('--D-max',
-                    type = int, default = 16)
+parser.add_argument("--D-max", type=int, default=16)
 
-parser.add_argument('--nb-runs',
-                    type = int, default = 250)
+parser.add_argument("--nb-runs", type=int, default=250)
 
-parser.add_argument('--nb-train-samples',
-                    type = int, default = 8)
+parser.add_argument("--nb-train-samples", type=int, default=8)
 
-parser.add_argument('--train-noise-std',
-                    type = float, default = 0.)
+parser.add_argument("--train-noise-std", type=float, default=0.0)
 
-parser.add_argument('--seed',
-                    type = int, default = 0,
-                    help = 'Random seed (default 0, < 0 is no seeding)')
+parser.add_argument(
+    "--seed", type=int, default=0, help="Random seed (default 0, < 0 is no seeding)"
+)
 
 args = parser.parse_args()
 
@@ -37,57 +35,70 @@ if args.seed >= 0:
 
 ######################################################################
 
+
 def pol_value(alpha, x):
     x_pow = x.view(-1, 1) ** torch.arange(alpha.size(0)).view(1, -1)
     return x_pow @ alpha
 
-def fit_alpha(x, y, D, a = 0, b = 1, rho = 1e-12):
+
+def fit_alpha(x, y, D, a=0, b=1, rho=1e-12):
     M = x.view(-1, 1) ** torch.arange(D + 1).view(1, -1)
     B = y
 
     if D >= 2:
-        q = torch.arange(2, D + 1, dtype = x.dtype).view(1, -1)
-        r = q.view(-1, 1)
+        q = torch.arange(2, D + 1, dtype=x.dtype).view(1, -1)
+        r = q.view(-1, 1)
         beta = x.new_zeros(D + 1, D + 1)
-        beta[2:, 2:] = (q-1) * q * (r-1) * r * (b**(q+r-3) - a**(q+r-3))/(q+r-3)
-        l, U = beta.eig(eigenvectors = True)
-        Q = U @ torch.diag(l[:, 0].clamp(min = 0) ** 0.5)
+        beta[2:, 2:] = (
+            (q - 1)
+            * q
+            * (r - 1)
+            * r
+            * (b ** (q + r - 3) - a ** (q + r - 3))
+            / (q + r - 3)
+        )
+        W = torch.linalg.eig(beta)
+        l, U = W.eigenvalues.real, W.eigenvectors.real
+        Q = U @ torch.diag(l.clamp(min=0) ** 0.5)  # clamp deals with ~0 negative values
         B = torch.cat((B, y.new_zeros(Q.size(0))), 0)
         M = torch.cat((M, math.sqrt(rho) * Q.t()), 0)
 
-    return torch.lstsq(B, M).solution[:D+1, 0]
+    return torch.linalg.lstsq(M, B).solution[: D + 1]
+
 
 ######################################################################
+# The "ground truth"
+
+
 def phi(x):
-    return torch.abs(torch.abs(x - 0.4) - 0.2) + x/2 - 0.1
+    return torch.abs(torch.abs(x - 0.4) - 0.2) + x / 2 - 0.1
+
 
 ######################################################################
+
 def compute_mse(nb_train_samples):
     mse_train = torch.zeros(args.nb_runs, args.D_max + 1)
     mse_test = torch.zeros(args.nb_runs, args.D_max + 1)
     for k in range(args.nb_runs):
-        x_train = torch.rand(nb_train_samples, dtype = torch.float64)
+        x_train = torch.rand(nb_train_samples, dtype=torch.float64)
         y_train = phi(x_train)
         if args.train_noise_std > 0:
-            y_train = y_train + torch.empty_like(y_train).normal_(0, args.train_noise_std)
-        x_test = torch.linspace(0, 1, 100, dtype = x_train.dtype)
+            y_train = y_train + torch.empty_like(y_train).normal_(
+                0, args.train_noise_std
+            )
+        x_test = torch.linspace(0, 1, 100, dtype=x_train.dtype)
         y_test = phi(x_test)
         for D in range(args.D_max + 1):
             alpha = fit_alpha(x_train, y_train, D)
-            mse_train[k, D] = ((pol_value(alpha, x_train) - y_train)**2).mean()
-            mse_test[k, D] = ((pol_value(alpha, x_test) - y_test)**2).mean()
+            mse_train[k, D] = ((pol_value(alpha, x_train) - y_train) ** 2).mean()
+            mse_test[k, D] = ((pol_value(alpha, x_test) - y_test) ** 2).mean()
 
     return mse_train.median(0).values, mse_test.median(0).values
 
-######################################################################
-
-torch.manual_seed(0)
-
-mse_train, mse_test = compute_mse(args.nb_train_samples)
 
 ######################################################################
 # Plot the MSE vs. degree curves
@@ -95,48 +106,99 @@ mse_train, mse_test = compute_mse(args.nb_train_samples)
 fig = plt.figure()
 
 ax = fig.add_subplot(1, 1, 1)
-ax.set_yscale('log')
+ax.set_yscale("log")
 ax.set_ylim(1e-5, 1)
-ax.set_xlabel('Polynomial degree', labelpad = 10)
-ax.set_ylabel('MSE', labelpad = 10)
+ax.set_xlabel("Polynomial degree", labelpad=10)
+ax.set_ylabel("MSE", labelpad=10)
 
-ax.axvline(x = args.nb_train_samples - 1, color = 'gray', linewidth = 0.5)
-ax.plot(torch.arange(args.D_max + 1), mse_train, color = 'blue', label = 'Train error')
-ax.plot(torch.arange(args.D_max + 1), mse_test, color = 'red', label = 'Test error')
+ax.axvline(x=args.nb_train_samples - 1, color="gray", linewidth=0.5, linestyle="--")
 
-ax.legend(frameon = False)
+ax.text(
+    args.nb_train_samples - 1.2,
+    1e-4,
+    "nb. params = nb. samples",
+    fontsize=10,
+    color="gray",
+    rotation=90,
+    rotation_mode="anchor",
+)
 
-fig.savefig('dd-mse.pdf', bbox_inches='tight')
+mse_train, mse_test = compute_mse(args.nb_train_samples)
+ax.plot(torch.arange(args.D_max + 1), mse_train, color="blue", label="Train")
+ax.plot(torch.arange(args.D_max + 1), mse_test, color="red", label="Test")
+
+ax.legend(frameon=False)
+
+fig.savefig("dd-mse.pdf", bbox_inches="tight")
+
+plt.close(fig)
+
+######################################################################
+# Plot multiple MSE vs. degree curves
+
+fig = plt.figure()
+
+ax = fig.add_subplot(1, 1, 1)
+ax.set_yscale("log")
+ax.set_ylim(1e-5, 1)
+ax.set_xlabel("Polynomial degree", labelpad=10)
+ax.set_ylabel("MSE", labelpad=10)
+
+nb_train_samples_min = args.nb_train_samples - 4
+nb_train_samples_max = args.nb_train_samples
+
+for nb_train_samples in range(nb_train_samples_min, nb_train_samples_max + 1, 2):
+    mse_train, mse_test = compute_mse(nb_train_samples)
+    e = float(nb_train_samples - nb_train_samples_min) / float(
+        nb_train_samples_max - nb_train_samples_min
+    )
+    e = 0.15 + 0.7 * e
+    ax.plot(
+        torch.arange(args.D_max + 1),
+        mse_train,
+        color=(e, e, 1.0),
+        label=f"Train N={nb_train_samples}",
+    )
+    ax.plot(
+        torch.arange(args.D_max + 1),
+        mse_test,
+        color=(1.0, e, e),
+        label=f"Test N={nb_train_samples}",
+    )
+
+ax.legend(frameon=False)
+
+fig.savefig("dd-multi-mse.pdf", bbox_inches="tight")
 
 plt.close(fig)
 
 ######################################################################
 # Plot some examples of train / test
 
-torch.manual_seed(9) # I picked that for pretty
+torch.manual_seed(9)  # I picked that for pretty
 
-x_train = torch.rand(args.nb_train_samples, dtype = torch.float64)
+x_train = torch.rand(args.nb_train_samples, dtype=torch.float64)
 y_train = phi(x_train)
 if args.train_noise_std > 0:
     y_train = y_train + torch.empty_like(y_train).normal_(0, args.train_noise_std)
-x_test = torch.linspace(0, 1, 100, dtype = x_train.dtype)
+x_test = torch.linspace(0, 1, 100, dtype=x_train.dtype)
 y_test = phi(x_test)
 
 for D in range(args.D_max + 1):
     fig = plt.figure()
     ax = fig.add_subplot(1, 1, 1)
-    ax.set_title(f'Degree {D}')
+    ax.set_title(f"Degree {D}")
     ax.set_ylim(-0.1, 1.1)
-    ax.plot(x_test, y_test, color = 'black', label = 'Test values')
-    ax.scatter(x_train, y_train, color = 'blue', label = 'Train samples')
+    ax.plot(x_test, y_test, color="black", label="Test values")
+    ax.scatter(x_train, y_train, color="blue", label="Train samples")
 
     alpha = fit_alpha(x_train, y_train, D)
-    ax.plot(x_test, pol_value(alpha, x_test), color = 'red', label = 'Fitted polynomial')
+    ax.plot(x_test, pol_value(alpha, x_test), color="red", label="Fitted polynomial")
 
-    ax.legend(frameon = False)
+    ax.legend(frameon=False)
 
-    fig.savefig(f'dd-example-{D:02d}.pdf', bbox_inches='tight')
+    fig.savefig(f"dd-example-{D:02d}.pdf", bbox_inches="tight")
 
     plt.close(fig)
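
A note on the algebra, reconstructed from the code in the patch rather than stated in the commit itself: fit_alpha solves a curvature-penalized least squares. Writing the fitted polynomial as p(x) = \sum_{d=0}^{D} \alpha_d x^d, it minimizes

    \sum_i \big( p(x_i) - y_i \big)^2 + \rho \int_a^b p''(x)^2 \, dx .

The penalty equals \alpha^\top \beta \alpha where, for q, r \ge 2,

    \beta_{qr} = \int_a^b q(q-1) x^{q-2} \, r(r-1) x^{r-2} \, dx
               = \frac{q(q-1)\, r(r-1) \big( b^{q+r-3} - a^{q+r-3} \big)}{q+r-3} ,

which is exactly the expression assigned to beta[2:, 2:]. Factoring \beta = Q Q^\top from its eigendecomposition (the clamp zeroes eigenvalues that are negative only through round-off) and appending \sqrt{\rho}\, Q^\top to M and zeros to B turns the penalized problem into the single plain least-squares solve

    \| M \alpha - B \|^2 = \sum_i \big( p(x_i) - y_i \big)^2 + \rho\, \alpha^\top \beta\, \alpha .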
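Beyond the black-style reformatting, the substantive changes are the two torch.linalg migrations: Tensor.eig and torch.lstsq, both removed from recent PyTorch, are replaced by torch.linalg.eig and torch.linalg.lstsq. A minimal sketch of the correspondence; this snippet is illustrative, not part of the commit, and assumes a PyTorch recent enough to ship torch.linalg (>= 1.9):

    import torch

    # A symmetric PSD matrix standing in for beta in fit_alpha.
    A = torch.tensor([[2.0, 1.0], [1.0, 2.0]], dtype=torch.float64)

    # Old API: l, U = A.eig(eigenvectors=True), where l had shape (n, 2) with
    # real and imaginary parts as columns, hence the old code's l[:, 0].
    # New API: torch.linalg.eig returns complex tensors, hence .real below.
    W = torch.linalg.eig(A)
    l, U = W.eigenvalues.real, W.eigenvectors.real
    Q = U @ torch.diag(l.clamp(min=0) ** 0.5)
    assert torch.allclose(Q @ Q.t(), A)  # A = Q Q^T for symmetric PSD A

    # Old API: torch.lstsq(B, M).solution was padded to max(m, n) rows, hence
    # the old [:D+1, 0] slice. New API: the arguments are swapped (matrix
    # first) and the solution has exactly n entries.
    M = torch.rand(10, 3, dtype=torch.float64)
    B = torch.rand(10, dtype=torch.float64)
    alpha = torch.linalg.lstsq(M, B).solution
    assert alpha.shape == (3,)

Read this way, the [: D + 1] slice kept in the new fit_alpha is a no-op carried over from the old padded indexing.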
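With the committed defaults, running the script, e.g.

    python ddpol.py

writes dd-mse.pdf, dd-multi-mse.pdf, and the per-degree fits dd-example-00.pdf through dd-example-16.pdf; --train-noise-std adds Gaussian noise to the training targets, and --nb-runs sets how many random training sets the median MSE curves are computed over.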