Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 3 Mar 2024 07:23:00 +0000 (08:23 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 3 Mar 2024 07:23:00 +0000 (08:23 +0100)
tiny_vae.py

index 784f775..b81df9a 100755 (executable)
@@ -65,12 +65,14 @@ def log_string(s):
 ######################################################################
 
 
-def sample_gaussian(mu, log_var):
+def sample_gaussian(param):
+    mu, log_var = param
     std = log_var.mul(0.5).exp()
     return torch.randn(mu.size(), device=mu.device) * std + mu
 
 
-def log_p_gaussian(x, mu, log_var):
+def log_p_gaussian(x, param):
+    mu, log_var = param
     var = log_var.exp()
     return (
         (-0.5 * ((x - mu).pow(2) / var) - 0.5 * log_var - 0.5 * math.log(2 * math.pi))
@@ -79,9 +81,9 @@ def log_p_gaussian(x, mu, log_var):
     )
 
 
-def dkl_gaussians(mean_a, log_var_a, mean_b, log_var_b):
-    mean_a, log_var_a = mean_a.flatten(1), log_var_a.flatten(1)
-    mean_b, log_var_b = mean_b.flatten(1), log_var_b.flatten(1)
+def dkl_gaussians(param_a, param_b):
+    mean_a, log_var_a = param_a[0].flatten(1), param_a[1].flatten(1)
+    mean_b, log_var_b = param_b[0].flatten(1), param_b[1].flatten(1)
     var_a = log_var_a.exp()
     var_b = log_var_b.exp()
     return 0.5 * (
@@ -181,28 +183,26 @@ test_input.sub_(train_mu).div_(train_std)
 
 ######################################################################
 
-mean_p_Z = train_input.new_zeros(1, args.latent_dim)
-log_var_p_Z = mean_p_Z
+zeros = train_input.new_zeros(1, args.latent_dim)
+
+param_p_Z = zeros, zeros
 
 for epoch in range(args.nb_epochs):
     acc_loss = 0
 
     for x in train_input.split(args.batch_size):
-        mean_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x)
-        z = sample_gaussian(mean_q_Z_given_x, log_var_q_Z_given_x)
-        mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
+        param_q_Z_given_x = model_q_Z_given_x(x)
+        z = sample_gaussian(param_q_Z_given_x)
+        param_p_X_given_z = model_p_X_given_z(z)
+        log_p_x_given_z = log_p_gaussian(x, param_p_X_given_z)
 
         if args.no_dkl:
-            log_q_z_given_x = log_p_gaussian(z, mean_q_Z_given_x, log_var_q_Z_given_x)
-            log_p_x_z = log_p_gaussian(
-                x, mean_p_X_given_z, log_var_p_X_given_z
-            ) + log_p_gaussian(z, mean_p_Z, log_var_p_Z)
+            log_q_z_given_x = log_p_gaussian(z, param_q_Z_given_x)
+            log_p_z = log_p_gaussian(z, param_p_Z)
+            log_p_x_z = log_p_x_given_z + log_p_x_z
             loss = -(log_p_x_z - log_q_z_given_x).mean()
         else:
-            log_p_x_given_z = log_p_gaussian(x, mean_p_X_given_z, log_var_p_X_given_z)
-            dkl_q_Z_given_x_from_p_Z = dkl_gaussians(
-                mean_q_Z_given_x, log_var_q_Z_given_x, mean_p_Z, log_var_p_Z
-            )
+            dkl_q_Z_given_x_from_p_Z = dkl_gaussians(param_q_Z_given_x, param_p_Z)
             loss = (-log_p_x_given_z + dkl_q_Z_given_x_from_p_Z).mean()
 
         optimizer.zero_grad()
@@ -229,17 +229,19 @@ save_image(x, "input.png")
 
 # Save the same images after encoding / decoding
 
-mean_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x)
-z = sample_gaussian(mean_q_Z_given_x, log_var_q_Z_given_x)
-mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
-x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z)
+param_q_Z_given_x = model_q_Z_given_x(x)
+z = sample_gaussian(param_q_Z_given_x)
+param_p_X_given_z = model_p_X_given_z(z)
+x = sample_gaussian(param_p_X_given_z)
 save_image(x, "output.png")
 
 # Generate a bunch of images
 
-z = sample_gaussian(mean_p_Z.expand(x.size(0), -1), log_var_p_Z.expand(x.size(0), -1))
-mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
-x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z)
+z = sample_gaussian(
+    param_p_Z[0].expand(x.size(0), -1), param_p_Z[1].expand(x.size(0), -1)
+)
+param_p_X_given_z = model_p_X_given_z(z)
+x = sample_gaussian(param_p_X_given_z)
 save_image(x, "synth.png")
 
 ######################################################################