X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=elbo.py;fp=elbo.py;h=dbea3b5213e3ddc92e23cbc7598d3755e76fe183;hp=6af4a7785e3a110328817569ea478642ed7c10e4;hb=05b9b133a45ac9bd5abe6f8b6d29095f9c82797a;hpb=ca897077ed89fbc3c7e8d812ad262146a0c72b71 diff --git a/elbo.py b/elbo.py index 6af4a77..dbea3b5 100755 --- a/elbo.py +++ b/elbo.py @@ -7,8 +7,10 @@ import torch + def D_KL(a, b): - return - a @ (b / a).log() + return -a @ (b / a).log() + # p(X = x, Z = z) = p[x, z] @@ -19,12 +21,12 @@ q_XZ /= q_XZ.sum() p_X = p_XZ.sum(1) p_Z = p_XZ.sum(0) -p_X_given_Z = p_XZ / p_XZ.sum(0, keepdim = True) -p_Z_given_X = p_XZ / p_XZ.sum(1, keepdim = True) +p_X_given_Z = p_XZ / p_XZ.sum(0, keepdim=True) +p_Z_given_X = p_XZ / p_XZ.sum(1, keepdim=True) -#q_X_given_Z = q_XZ / q_XZ.sum(0, keepdim = True) -q_Z_given_X = q_XZ / q_XZ.sum(1, keepdim = True) +# q_X_given_Z = q_XZ / q_XZ.sum(0, keepdim = True) +q_Z_given_X = q_XZ / q_XZ.sum(1, keepdim=True) for x in range(p_XZ.size(0)): - elbo = q_Z_given_X[x, :] @ ( p_X_given_Z[x, :] / q_Z_given_X[x, :] * p_Z).log() + elbo = q_Z_given_X[x, :] @ (p_X_given_Z[x, :] / q_Z_given_X[x, :] * p_Z).log() print(p_X[x].log(), elbo + D_KL(q_Z_given_X[x, :], p_Z_given_X[x, :]))