Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 4 Mar 2024 06:20:22 +0000 (07:20 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 4 Mar 2024 06:20:22 +0000 (07:20 +0100)
tiny_vae.py

index 405c103..10ce19f 100755 (executable)
@@ -30,7 +30,7 @@ parser = argparse.ArgumentParser(
 
 parser.add_argument("--nb_epochs", type=int, default=100)
 
-parser.add_argument("--learning_rate", type=float, default=2e-4)
+parser.add_argument("--learning_rate", type=float, default=1e-4)
 
 parser.add_argument("--batch_size", type=int, default=100)
 
@@ -40,10 +40,12 @@ parser.add_argument("--log_filename", type=str, default="train.log")
 
 parser.add_argument("--latent_dim", type=int, default=32)
 
-parser.add_argument("--nb_channels", type=int, default=128)
+parser.add_argument("--nb_channels", type=int, default=64)
 
 parser.add_argument("--no_dkl", action="store_true")
 
+parser.add_argument("--beta", type=float, default=1.0)
+
 args = parser.parse_args()
 
 log_file = open(args.log_filename, "w")
@@ -157,6 +159,55 @@ test_input = test_set.data.view(-1, 1, 28, 28).float()
 
 ######################################################################
 
+
+def save_images(model_q_Z_given_x, model_p_X_given_z, prefix=""):
+    def save_image(x, filename):
+        x = x * train_std + train_mu
+        x = x.clamp(min=0, max=255) / 255
+        torchvision.utils.save_image(1 - x, filename, nrow=16, pad_value=0.8)
+        log_string(f"wrote {filename}")
+
+    # Save a bunch of train images
+
+    x = train_input[:256]
+    save_image(x, f"{prefix}train_input.png")
+
+    # Save the same images after encoding / decoding
+
+    param_q_Z_given_x = model_q_Z_given_x(x)
+    z = sample_gaussian(param_q_Z_given_x)
+    param_p_X_given_z = model_p_X_given_z(z)
+    x = sample_gaussian(param_p_X_given_z)
+    save_image(x, f"{prefix}train_output.png")
+    save_image(param_p_X_given_z[0], f"{prefix}train_output_mean.png")
+
+    # Save a bunch of test images
+
+    x = test_input[:256]
+    save_image(x, f"{prefix}input.png")
+
+    # Save the same images after encoding / decoding
+
+    param_q_Z_given_x = model_q_Z_given_x(x)
+    z = sample_gaussian(param_q_Z_given_x)
+    param_p_X_given_z = model_p_X_given_z(z)
+    x = sample_gaussian(param_p_X_given_z)
+    save_image(x, f"{prefix}output.png")
+    save_image(param_p_X_given_z[0], f"{prefix}output_mean.png")
+
+    # Generate a bunch of images
+
+    z = sample_gaussian(
+        (param_p_Z[0].expand(x.size(0), -1), param_p_Z[1].expand(x.size(0), -1))
+    )
+    param_p_X_given_z = model_p_X_given_z(z)
+    x = sample_gaussian(param_p_X_given_z)
+    save_image(x, f"{prefix}synth.png")
+    save_image(param_p_X_given_z[0], f"{prefix}synth_mean.png")
+
+
+######################################################################
+
 model_q_Z_given_x = LatentGivenImageNet(
     nb_channels=args.nb_channels, latent_dim=args.latent_dim
 )
@@ -187,7 +238,7 @@ zeros = train_input.new_zeros(1, args.latent_dim)
 
 param_p_Z = zeros, zeros
 
-for epoch in range(args.nb_epochs):
+for n_epoch in range(args.nb_epochs):
     acc_loss = 0
 
     for x in train_input.split(args.batch_size):
@@ -203,7 +254,7 @@ for epoch in range(args.nb_epochs):
             loss = -(log_p_x_z - log_q_z_given_x).mean()
         else:
             dkl_q_Z_given_x_from_p_Z = dkl_gaussians(param_q_Z_given_x, param_p_Z)
-            loss = (-log_p_x_given_z + dkl_q_Z_given_x_from_p_Z).mean()
+            loss = -(log_p_x_given_z - args.beta * dkl_q_Z_given_x_from_p_Z).mean()
 
         optimizer.zero_grad()
         loss.backward()
@@ -211,39 +262,9 @@ for epoch in range(args.nb_epochs):
 
         acc_loss += loss.item() * x.size(0)
 
-    log_string(f"acc_loss {epoch} {acc_loss/train_input.size(0)}")
-
-######################################################################
-
-
-def save_image(x, filename):
-    x = x * train_std + train_mu
-    x = x.clamp(min=0, max=255) / 255
-    torchvision.utils.save_image(1 - x, filename, nrow=16, pad_value=0.8)
+    log_string(f"acc_loss {n_epoch} {acc_loss/train_input.size(0)}")
 
-
-# Save a bunch of test images
-
-x = test_input[:256]
-save_image(x, "input.png")
-
-# Save the same images after encoding / decoding
-
-param_q_Z_given_x = model_q_Z_given_x(x)
-z = sample_gaussian(param_q_Z_given_x)
-param_p_X_given_z = model_p_X_given_z(z)
-x = sample_gaussian(param_p_X_given_z)
-save_image(x, "output.png")
-save_image(param_p_X_given_z[0], "output_mean.png")
-
-# Generate a bunch of images
-
-z = sample_gaussian(
-    (param_p_Z[0].expand(x.size(0), -1), param_p_Z[1].expand(x.size(0), -1))
-)
-param_p_X_given_z = model_p_X_given_z(z)
-x = sample_gaussian(param_p_X_given_z)
-save_image(x, "synth.png")
-save_image(param_p_X_given_z[0], "synth_mean.png")
+    if (n_epoch + 1) % 25 == 0:
+        save_images(model_q_Z_given_x, model_p_X_given_z, f"epoch_{n_epoch+1:04d}_")
 
 ######################################################################