From c951f0b1b425dc91ba74e9cb75425b0ad2f481ac Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 1 Mar 2024 19:15:54 +0100 Subject: [PATCH] Update. --- picocrafter.py | 11 +-- tiny_vae.py | 240 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 246 insertions(+), 5 deletions(-) create mode 100755 tiny_vae.py diff --git a/picocrafter.py b/picocrafter.py index 5bd6a48..23d93b2 100755 --- a/picocrafter.py +++ b/picocrafter.py @@ -23,8 +23,8 @@ # iterations. # # The environment is a rectangular area with walls "#" dispatched -# randomly. The agent "@" can perform five actions: move NESW or do -# not move. +# randomly. The agent "@" can perform five actions: move "NESW" or be +# immobile "I". # # There are monsters "$" moving randomly. The agent gets hit by every # monster present in one of the 4 direct neighborhoods at the end of @@ -39,8 +39,9 @@ # "B", "C"). The keys and vault can only be used in sequence: # initially the agent can move only to free spaces, or to the "a", in # which case the key is removed from the environment and the agent now -# carries it, and can move to free spaces or the "A". When it moves to -# the "A", it gets a reward, loses the "a", the "A" is removed from +# carries it, it appears in the inventory at the bottom of the frame, +# and the agent can now move to free spaces or the "A". When it moves +# to the "A", it gets a reward, loses the "a", the "A" is removed from # the environment, but the agent can now move to the "b", etc. Rewards # are 1 for "A" and "B" and 10 for "C". @@ -244,7 +245,7 @@ class PicroCrafterEnvironment: def action2str(self, n): if n >= 0 and n < 5: - return "XNESW"[n] + return "INESW"[n] else: return "?" 
#!/usr/bin/env python

# @XREMOTE_HOST: elk.fleuret.org
# @XREMOTE_EXEC: python
# @XREMOTE_PRE: source ${HOME}/misc/venv/pytorch/bin/activate
# @XREMOTE_PRE: ln -sf ${HOME}/data/pytorch ./data
# @XREMOTE_GET: *.png

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret

# tiny_vae.py
#
# Trains a small convolutional variational auto-encoder on MNIST with
# Gaussian encoder, decoder and prior, then writes three image grids:
# "input.png" (test images), "output.png" (the same images after
# encoding / decoding) and "synth.png" (images decoded from latents
# sampled from the prior).

import argparse
import atexit
import itertools
import math
import os
import sys
import time

import torch
import torchvision

from torch import optim, nn
from torch.nn import functional as F

######################################################################

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

######################################################################

parser = argparse.ArgumentParser(description="Tiny LeNet-like auto-encoder.")

parser.add_argument("--nb_epochs", type=int, default=25)

parser.add_argument("--batch_size", type=int, default=100)

parser.add_argument("--data_dir", type=str, default="./data/")

parser.add_argument("--log_filename", type=str, default="train.log")

parser.add_argument("--latent_dim", type=int, default=32)

parser.add_argument("--nb_channels", type=int, default=128)

parser.add_argument("--no_dkl", action="store_true")

args = parser.parse_args()

log_file = open(args.log_filename, "w")

# The log file lives for the whole run; make sure it is closed when the
# interpreter exits, including when the script stops on an exception.
atexit.register(log_file.close)

######################################################################


def log_string(s):
    """Write the time-stamped message `s` to the log file and to stdout.

    Both streams are flushed immediately so that the log can be
    followed while the training is running.
    """
    t = time.strftime("%Y-%m-%d_%H:%M:%S - ", time.localtime())

    if log_file is not None:
        log_file.write(t + s + "\n")
        log_file.flush()

    print(t + s)
    sys.stdout.flush()


######################################################################


def sample_gaussian(mu, log_var):
    """Return a sample of N(mu, exp(log_var)) of same size as `mu`.

    The sample is computed as `noise * std + mu` with `noise` a
    standard normal tensor, so it remains differentiable with respect
    to `mu` and `log_var`.
    """
    std = log_var.mul(0.5).exp()
    # Draw the noise with the dtype of mu (and not the default dtype)
    # so that no silent type promotion occurs with non-float32 inputs.
    noise = torch.randn(mu.size(), device=mu.device, dtype=mu.dtype)
    return noise * std + mu


def log_p_gaussian(x, mu, log_var):
    """Return the log-density of `x` under N(mu, exp(log_var)).

    The Gaussian is diagonal: the per-component log-densities are
    summed over all the dimensions but the first, resulting in one
    value per sample.
    """
    var = log_var.exp()
    return (
        (-0.5 * ((x - mu).pow(2) / var) - 0.5 * log_var - 0.5 * math.log(2 * math.pi))
        .flatten(1)
        .sum(1)
    )


def dkl_gaussians(mu_a, log_var_a, mu_b, log_var_b):
    """Return KL(N(mu_a, exp(log_var_a)) || N(mu_b, exp(log_var_b))).

    Both Gaussians are diagonal. All the dimensions but the first are
    flattened and summed, resulting in one value per sample.
    """
    mu_a, log_var_a = mu_a.flatten(1), log_var_a.flatten(1)
    mu_b, log_var_b = mu_b.flatten(1), log_var_b.flatten(1)
    var_a = log_var_a.exp()
    var_b = log_var_b.exp()
    return 0.5 * (
        log_var_b - log_var_a - 1 + (mu_a - mu_b).pow(2) / var_b + var_a / var_b
    ).sum(1)


######################################################################


class LatentGivenImageNet(nn.Module):
    """Encoder q(Z | x): maps an image to the mean and the log-variance
    of a diagonal Gaussian over the latent."""

    def __init__(self, nb_channels, latent_dim):
        super().__init__()

        self.model = nn.Sequential(
            nn.Conv2d(1, nb_channels, kernel_size=1),  # to 28x28
            nn.ReLU(inplace=True),
            nn.Conv2d(nb_channels, nb_channels, kernel_size=5),  # to 24x24
            nn.ReLU(inplace=True),
            nn.Conv2d(nb_channels, nb_channels, kernel_size=5),  # to 20x20
            nn.ReLU(inplace=True),
            nn.Conv2d(nb_channels, nb_channels, kernel_size=4, stride=2),  # to 9x9
            nn.ReLU(inplace=True),
            nn.Conv2d(nb_channels, nb_channels, kernel_size=3, stride=2),  # to 4x4
            nn.ReLU(inplace=True),
            nn.Conv2d(nb_channels, 2 * latent_dim, kernel_size=4),
        )

    def forward(self, x):
        # The output channels are split in two halves: the first one is
        # the mean, the second one the log-variance.
        output = self.model(x).view(x.size(0), 2, -1)
        mu, log_var = output[:, 0], output[:, 1]
        return mu, log_var


class ImageGivenLatentNet(nn.Module):
    """Decoder p(X | z): maps a latent to the per-pixel mean and
    log-variance of a diagonal Gaussian over the image."""

    def __init__(self, nb_channels, latent_dim):
        super().__init__()

        self.model = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, nb_channels, kernel_size=4),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(
                nb_channels, nb_channels, kernel_size=3, stride=2
            ),  # from 4x4
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(
                nb_channels, nb_channels, kernel_size=4, stride=2
            ),  # from 9x9
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size=5),  # from 20x20
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(nb_channels, 2, kernel_size=5),  # from 24x24
        )

    def forward(self, z):
        # The latent is processed as a 1x1 feature map. The first
        # output channel is the mean, the second one the log-variance.
        output = self.model(z.view(z.size(0), -1, 1, 1))
        mu, log_var = output[:, 0:1], output[:, 1:2]
        return mu, log_var


######################################################################

data_dir = os.path.join(args.data_dir, "mnist")

train_set = torchvision.datasets.MNIST(data_dir, train=True, download=True)
train_input = train_set.data.view(-1, 1, 28, 28).float()

test_set = torchvision.datasets.MNIST(data_dir, train=False, download=True)
test_input = test_set.data.view(-1, 1, 28, 28).float()

######################################################################

model_q_Z_given_x = LatentGivenImageNet(
    nb_channels=args.nb_channels, latent_dim=args.latent_dim
)

model_p_X_given_z = ImageGivenLatentNet(
    nb_channels=args.nb_channels, latent_dim=args.latent_dim
)

# Move the models to the device before building the optimizer, as
# recommended by the torch.optim documentation.
model_p_X_given_z.to(device)
model_q_Z_given_x.to(device)

optimizer = optim.Adam(
    itertools.chain(model_p_X_given_z.parameters(), model_q_Z_given_x.parameters()),
    lr=4e-4,
)

######################################################################

train_input, test_input = train_input.to(device), test_input.to(device)

# Both sets are normalized with the statistics of the train set
train_mu, train_std = train_input.mean(), train_input.std()
train_input.sub_(train_mu).div_(train_std)
test_input.sub_(train_mu).div_(train_std)

######################################################################

# The prior p(Z) is a standard normal: zero mean and zero log-variance.
# Two distinct tensors are used so that an in-place operation on one
# cannot modify the other.
mu_p_Z = train_input.new_zeros(1, args.latent_dim)
log_var_p_Z = train_input.new_zeros(1, args.latent_dim)

for epoch in range(args.nb_epochs):
    acc_loss = 0

    for x in train_input.split(args.batch_size):
        mu_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x)
        z = sample_gaussian(mu_q_Z_given_x, log_var_q_Z_given_x)
        mu_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)

        if args.no_dkl:
            # Estimate of the negative ELBO computed from the sampled
            # z alone: -(log p(x, z) - log q(z | x))
            log_q_z_given_x = log_p_gaussian(z, mu_q_Z_given_x, log_var_q_Z_given_x)
            log_p_x_z = log_p_gaussian(
                x, mu_p_X_given_z, log_var_p_X_given_z
            ) + log_p_gaussian(z, mu_p_Z, log_var_p_Z)
            loss = -(log_p_x_z - log_q_z_given_x).mean()
        else:
            # Negative ELBO with the closed-form KL divergence between
            # q(Z | x) and the prior: -log p(x | z) + KL(q(Z | x) || p(Z))
            log_p_x_given_z = log_p_gaussian(x, mu_p_X_given_z, log_var_p_X_given_z)
            dkl_q_Z_given_x_from_p_Z = dkl_gaussians(
                mu_q_Z_given_x, log_var_q_Z_given_x, mu_p_Z, log_var_p_Z
            )
            loss = (-log_p_x_given_z + dkl_q_Z_given_x_from_p_Z).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The loss is a mean over the batch, weight it by the batch
        # size to get a proper per-sample average over the epoch
        acc_loss += loss.item() * x.size(0)

    log_string(f"acc_loss {epoch} {acc_loss/train_input.size(0)}")

######################################################################


def save_image(x, filename):
    """Save the batch of normalized images `x` as a grid in `filename`.

    The train-set normalization is undone, the values are clamped to
    [0, 255] and rescaled to [0, 1], and the gray levels are inverted
    before saving, with 16 images per row.
    """
    x = x * train_std + train_mu
    x = x.clamp(min=0, max=255) / 255
    torchvision.utils.save_image(1 - x, filename, nrow=16, pad_value=0.8)


# Nothing is back-propagated from here on, so do not build the autograd
# graphs for the forward passes below

with torch.no_grad():
    # Save a bunch of test images

    x = test_input[:256]
    save_image(x, "input.png")

    # Save the same images after encoding / decoding

    mu_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x)
    z = sample_gaussian(mu_q_Z_given_x, log_var_q_Z_given_x)
    mu_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
    x = sample_gaussian(mu_p_X_given_z, log_var_p_X_given_z)
    save_image(x, "output.png")

    # Generate a bunch of images

    z = sample_gaussian(
        mu_p_Z.expand(x.size(0), -1), log_var_p_Z.expand(x.size(0), -1)
    )
    mu_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
    x = sample_gaussian(mu_p_X_given_z, log_var_p_X_given_z)
    save_image(x, "synth.png")

######################################################################