X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tiny_vae.py;h=405c103ca8304098e5380e85508274cc9e3606df;hb=ae1c9180165d9264e9cfe152bb64164926b5ddd2;hp=577f717e8c61d14cb98a5f9a99ae1af22e90cdbc;hpb=dc1e3534151307491a1eacf053fc4aede631448b;p=pytorch.git diff --git a/tiny_vae.py b/tiny_vae.py index 577f717..405c103 100755 --- a/tiny_vae.py +++ b/tiny_vae.py @@ -24,10 +24,14 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") ###################################################################### -parser = argparse.ArgumentParser(description="Tiny LeNet-like auto-encoder.") +parser = argparse.ArgumentParser( + description="Very simple implementation of a VAE for teaching." +) parser.add_argument("--nb_epochs", type=int, default=100) +parser.add_argument("--learning_rate", type=float, default=2e-4) + parser.add_argument("--batch_size", type=int, default=100) parser.add_argument("--data_dir", type=str, default="./data/") @@ -61,12 +65,14 @@ def log_string(s): ###################################################################### -def sample_gaussian(mu, log_var): +def sample_gaussian(param): + mu, log_var = param std = log_var.mul(0.5).exp() return torch.randn(mu.size(), device=mu.device) * std + mu -def log_p_gaussian(x, mu, log_var): +def log_p_gaussian(x, param): + mu, log_var = param var = log_var.exp() return ( (-0.5 * ((x - mu).pow(2) / var) - 0.5 * log_var - 0.5 * math.log(2 * math.pi)) @@ -75,9 +81,9 @@ def log_p_gaussian(x, mu, log_var): ) -def dkl_gaussians(mean_a, log_var_a, mean_b, log_var_b): - mean_a, log_var_a = mean_a.flatten(1), log_var_a.flatten(1) - mean_b, log_var_b = mean_b.flatten(1), log_var_b.flatten(1) +def dkl_gaussians(param_a, param_b): + mean_a, log_var_a = param_a[0].flatten(1), param_a[1].flatten(1) + mean_b, log_var_b = param_b[0].flatten(1), param_b[1].flatten(1) var_a = log_var_a.exp() var_b = log_var_b.exp() return 0.5 * ( @@ -135,6 +141,7 @@ class ImageGivenLatentNet(nn.Module): def forward(self, z): output = self.model(z.view(z.size(0), -1, 1, 1)) mu, log_var = output[:, 0:1], output[:, 1:2] + # log_var.flatten(1)[...] = log_var.flatten(1)[:, :1] return mu, log_var @@ -160,7 +167,7 @@ model_p_X_given_z = ImageGivenLatentNet( optimizer = optim.Adam( itertools.chain(model_p_X_given_z.parameters(), model_q_Z_given_x.parameters()), - lr=4e-4, + lr=args.learning_rate, ) model_p_X_given_z.to(device) @@ -176,28 +183,26 @@ test_input.sub_(train_mu).div_(train_std) ###################################################################### -mean_p_Z = train_input.new_zeros(1, args.latent_dim) -log_var_p_Z = mean_p_Z +zeros = train_input.new_zeros(1, args.latent_dim) + +param_p_Z = zeros, zeros for epoch in range(args.nb_epochs): acc_loss = 0 for x in train_input.split(args.batch_size): - mean_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x) - z = sample_gaussian(mean_q_Z_given_x, log_var_q_Z_given_x) - mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z) + param_q_Z_given_x = model_q_Z_given_x(x) + z = sample_gaussian(param_q_Z_given_x) + param_p_X_given_z = model_p_X_given_z(z) + log_p_x_given_z = log_p_gaussian(x, param_p_X_given_z) if args.no_dkl: - log_q_z_given_x = log_p_gaussian(z, mean_q_Z_given_x, log_var_q_Z_given_x) - log_p_x_z = log_p_gaussian( - x, mean_p_X_given_z, log_var_p_X_given_z - ) + log_p_gaussian(z, mean_p_Z, log_var_p_Z) + log_q_z_given_x = log_p_gaussian(z, param_q_Z_given_x) + log_p_z = log_p_gaussian(z, param_p_Z) + log_p_x_z = log_p_x_given_z + log_p_x_z loss = -(log_p_x_z - log_q_z_given_x).mean() else: - log_p_x_given_z = log_p_gaussian(x, mean_p_X_given_z, log_var_p_X_given_z) - dkl_q_Z_given_x_from_p_Z = dkl_gaussians( - mean_q_Z_given_x, log_var_q_Z_given_x, mean_p_Z, log_var_p_Z - ) + dkl_q_Z_given_x_from_p_Z = dkl_gaussians(param_q_Z_given_x, param_p_Z) loss = (-log_p_x_given_z + dkl_q_Z_given_x_from_p_Z).mean() optimizer.zero_grad() @@ -224,17 +229,21 @@ save_image(x, "input.png") # Save the same images after encoding / decoding -mean_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x) -z = sample_gaussian(mean_q_Z_given_x, log_var_q_Z_given_x) -mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z) -x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z) +param_q_Z_given_x = model_q_Z_given_x(x) +z = sample_gaussian(param_q_Z_given_x) +param_p_X_given_z = model_p_X_given_z(z) +x = sample_gaussian(param_p_X_given_z) save_image(x, "output.png") +save_image(param_p_X_given_z[0], "output_mean.png") # Generate a bunch of images -z = sample_gaussian(mean_p_Z.expand(x.size(0), -1), log_var_p_Z.expand(x.size(0), -1)) -mean_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z) -x = sample_gaussian(mean_p_X_given_z, log_var_p_X_given_z) +z = sample_gaussian( + (param_p_Z[0].expand(x.size(0), -1), param_p_Z[1].expand(x.size(0), -1)) +) +param_p_X_given_z = model_p_X_given_z(z) +x = sample_gaussian(param_p_X_given_z) save_image(x, "synth.png") +save_image(param_p_X_given_z[0], "synth_mean.png") ######################################################################