X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=elbo.py;fp=elbo.py;h=24155fea9d73c1d0ef1f10e5496099c19b839140;hp=0000000000000000000000000000000000000000;hb=c16fa89db08b59e454c6ca4b5c68bf7396e876dc;hpb=5be92059a4bc81db8aad8b677d3387800de2aae8 diff --git a/elbo.py b/elbo.py new file mode 100755 index 0000000..24155fe --- /dev/null +++ b/elbo.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import torch + +def D_KL(p, q): + return - p @ (q / p).log() + +# p(X = x, Z = z) = p[x, z] +p = torch.rand(5, 4) +p /= p.sum() + +q = torch.rand(p.size()) +q /= q.sum() + +p_X = p.sum(1) +p_Z = p.sum(0) +p_X_given_Z = p / p.sum(0, keepdim = True) +p_Z_given_X = p / p.sum(1, keepdim = True) +q_X_given_Z = q / q.sum(0, keepdim = True) +q_Z_given_X = q / q.sum(1, keepdim = True) + +for x in range(p.size(0)): + elbo = q_Z_given_X[x, :] @ ( p_X_given_Z[x, :] / q_Z_given_X[x, :] * p_Z).log() + print(p_X[x].log(), elbo + D_KL(q_Z_given_X[x, :], p_Z_given_X[x, :]))