r = q.view(-1, 1)
beta = x.new_zeros(D + 1, D + 1)
beta[2:, 2:] = (q-1) * q * (r-1) * r * (b**(q+r-3) - a**(q+r-3))/(q+r-3)
- l, U = beta.eig(eigenvectors = True)
- Q = U @ torch.diag(l[:, 0].clamp(min = 0) ** 0.5) # clamp deals with ~0 negative values
+ W = torch.linalg.eig(beta)
+ l, U = W.eigenvalues.real, W.eigenvectors.real
+ Q = U @ torch.diag(l.clamp(min = 0) ** 0.5) # clamp deals with ~0 negative values
B = torch.cat((B, y.new_zeros(Q.size(0))), 0)
M = torch.cat((M, math.sqrt(rho) * Q.t()), 0)
- return torch.lstsq(B, M).solution[:D+1, 0]
+ return torch.linalg.lstsq(M, B).solution[:D+1]
######################################################################
+# The "ground truth"
+
def phi(x):
    # Piecewise-linear "ground truth" target: a double-V kink pattern
    # (centered at 0.2 and 0.6) superimposed on the slope x/2 - 0.1.
    kink = torch.abs(x - 0.4)
    return torch.abs(kink - 0.2) + x / 2 - 0.1
ax.axvline(x = args.nb_train_samples - 1,
color = 'gray', linewidth = 0.5, linestyle = '--')
-ax.text(args.nb_train_samples - 1.2, 1e-4, 'Nb. params = nb. samples',
+ax.text(args.nb_train_samples - 1.2, 1e-4, 'nb. params = nb. samples',
fontsize = 10, color = 'gray',
rotation = 90, rotation_mode='anchor')