From: Francois Fleuret Date: Fri, 10 Dec 2021 21:53:37 +0000 (+0100) Subject: Update. X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=commitdiff_plain;h=c16fa89db08b59e454c6ca4b5c68bf7396e876dc Update. --- diff --git a/elbo.py b/elbo.py new file mode 100755 index 0000000..24155fe --- /dev/null +++ b/elbo.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import torch + +def D_KL(p, q): + return - p @ (q / p).log() + +# p(X = x, Z = z) = p[x, z] +p = torch.rand(5, 4) +p /= p.sum() + +q = torch.rand(p.size()) +q /= q.sum() + +p_X = p.sum(1) +p_Z = p.sum(0) +p_X_given_Z = p / p.sum(0, keepdim = True) +p_Z_given_X = p / p.sum(1, keepdim = True) +q_X_given_Z = q / q.sum(0, keepdim = True) +q_Z_given_X = q / q.sum(1, keepdim = True) + +for x in range(p.size(0)): + elbo = q_Z_given_X[x, :] @ ( p_X_given_Z[x, :] / q_Z_given_X[x, :] * p_Z).log() + print(p_X[x].log(), elbo + D_KL(q_Z_given_X[x, :], p_Z_given_X[x, :]))