X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=elbo.py;fp=elbo.py;h=6af4a7785e3a110328817569ea478642ed7c10e4;hp=24155fea9d73c1d0ef1f10e5496099c19b839140;hb=47525ec795faca1ab72aee13956a553d070c81b7;hpb=c16fa89db08b59e454c6ca4b5c68bf7396e876dc diff --git a/elbo.py b/elbo.py index 24155fe..6af4a77 100755 --- a/elbo.py +++ b/elbo.py @@ -7,23 +7,24 @@ import torch -def D_KL(p, q): - return - p @ (q / p).log() +def D_KL(a, b): + return - a @ (b / a).log() # p(X = x, Z = z) = p[x, z] -p = torch.rand(5, 4) -p /= p.sum() -q = torch.rand(p.size()) -q /= q.sum() +p_XZ = torch.rand(5, 4) +p_XZ /= p_XZ.sum() +q_XZ = torch.rand(p_XZ.size()) +q_XZ /= q_XZ.sum() -p_X = p.sum(1) -p_Z = p.sum(0) -p_X_given_Z = p / p.sum(0, keepdim = True) -p_Z_given_X = p / p.sum(1, keepdim = True) -q_X_given_Z = q / q.sum(0, keepdim = True) -q_Z_given_X = q / q.sum(1, keepdim = True) +p_X = p_XZ.sum(1) +p_Z = p_XZ.sum(0) +p_X_given_Z = p_XZ / p_XZ.sum(0, keepdim = True) +p_Z_given_X = p_XZ / p_XZ.sum(1, keepdim = True) -for x in range(p.size(0)): +#q_X_given_Z = q_XZ / q_XZ.sum(0, keepdim = True) +q_Z_given_X = q_XZ / q_XZ.sum(1, keepdim = True) + +for x in range(p_XZ.size(0)): elbo = q_Z_given_X[x, :] @ ( p_X_given_Z[x, :] / q_Z_given_X[x, :] * p_Z).log() print(p_X[x].log(), elbo + D_KL(q_Z_given_X[x, :], p_Z_given_X[x, :]))