Update.
[pytorch.git] / elbo.py
diff --git a/elbo.py b/elbo.py
index 6af4a77..dbea3b5 100755 (executable)
--- a/elbo.py
+++ b/elbo.py
@@ -7,8 +7,10 @@
 
 import torch
 
+
 def D_KL(a, b):
-    return - a @ (b / a).log()
+    return -a @ (b / a).log()
+
 
 # p(X = x, Z = z) = p[x, z]
 
@@ -19,12 +21,12 @@ q_XZ /= q_XZ.sum()
 
 p_X = p_XZ.sum(1)
 p_Z = p_XZ.sum(0)
-p_X_given_Z = p_XZ / p_XZ.sum(0, keepdim = True)
-p_Z_given_X = p_XZ / p_XZ.sum(1, keepdim = True)
+p_X_given_Z = p_XZ / p_XZ.sum(0, keepdim=True)
+p_Z_given_X = p_XZ / p_XZ.sum(1, keepdim=True)
 
-#q_X_given_Z = q_XZ / q_XZ.sum(0, keepdim = True)
-q_Z_given_X = q_XZ / q_XZ.sum(1, keepdim = True)
+# q_X_given_Z = q_XZ / q_XZ.sum(0, keepdim = True)
+q_Z_given_X = q_XZ / q_XZ.sum(1, keepdim=True)
 
 for x in range(p_XZ.size(0)):
-    elbo = q_Z_given_X[x, :] @ ( p_X_given_Z[x, :] / q_Z_given_X[x, :] * p_Z).log()
+    elbo = q_Z_given_X[x, :] @ (p_X_given_Z[x, :] / q_Z_given_X[x, :] * p_Z).log()
     print(p_X[x].log(), elbo + D_KL(q_Z_given_X[x, :], p_Z_given_X[x, :]))