Initial commit.
authorFrancois Fleuret <francois@fleuret.org>
Tue, 7 Dec 2021 07:19:33 +0000 (08:19 +0100)
committerFrancois Fleuret <francois@fleuret.org>
Tue, 7 Dec 2021 07:19:33 +0000 (08:19 +0100)
tinyae.py [new file with mode: 0755]

diff --git a/tinyae.py b/tinyae.py
new file mode 100755 (executable)
index 0000000..c608c9c
--- /dev/null
+++ b/tinyae.py
@@ -0,0 +1,163 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import sys, argparse, time
+
+import torch, torchvision
+
+from torch import optim, nn
+from torch.nn import functional as F
+
+######################################################################
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+######################################################################
+
+parser = argparse.ArgumentParser(description = 'Tiny LeNet-like auto-encoder.')
+
+parser.add_argument('--nb_epochs',
+                    type = int, default = 25)
+
+parser.add_argument('--batch_size',
+                    type = int, default = 100)
+
+parser.add_argument('--data_dir',
+                    type = str, default = './data/')
+
+parser.add_argument('--log_filename',
+                    type = str, default = 'train.log')
+
+parser.add_argument('--embedding_dim',
+                    type = int, default = 8)
+
+parser.add_argument('--nb_channels',
+                    type = int, default = 32)
+
+parser.add_argument('--force_train',
+                    type = bool, default = False)
+
+args = parser.parse_args()
+
+log_file = open(args.log_filename, 'w')
+
+######################################################################
+
+def log_string(s):
+    t = time.strftime("%Y-%m-%d_%H:%M:%S - ", time.localtime())
+
+    if log_file is not None:
+        log_file.write(t + s + '\n')
+        log_file.flush()
+
+    print(t + s)
+    sys.stdout.flush()
+
+######################################################################
+
+class AutoEncoder(nn.Module):
+    def __init__(self, nb_channels, embedding_dim):
+        super(AutoEncoder, self).__init__()
+
+        self.encoder = nn.Sequential(
+            nn.Conv2d(1, nb_channels, kernel_size = 5), # to 24x24
+            nn.ReLU(inplace = True),
+            nn.Conv2d(nb_channels, nb_channels, kernel_size = 5), # to 20x20
+            nn.ReLU(inplace = True),
+            nn.Conv2d(nb_channels, nb_channels, kernel_size = 4, stride = 2), # to 9x9
+            nn.ReLU(inplace = True),
+            nn.Conv2d(nb_channels, nb_channels, kernel_size = 3, stride = 2), # to 4x4
+            nn.ReLU(inplace = True),
+            nn.Conv2d(nb_channels, embedding_dim, kernel_size = 4)
+        )
+
+        self.decoder = nn.Sequential(
+            nn.ConvTranspose2d(embedding_dim, nb_channels, kernel_size = 4),
+            nn.ReLU(inplace = True),
+            nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size = 3, stride = 2), # from 4x4
+            nn.ReLU(inplace = True),
+            nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size = 4, stride = 2), # from 9x9
+            nn.ReLU(inplace = True),
+            nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size = 5), # from 20x20
+            nn.ReLU(inplace = True),
+            nn.ConvTranspose2d(nb_channels, 1, kernel_size = 5), # from 24x24
+        )
+
+    def encode(self, x):
+        return self.encoder(x).view(x.size(0), -1)
+
+    def decode(self, z):
+        return self.decoder(z.view(z.size(0), -1, 1, 1))
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.decoder(x)
+        return x
+
+######################################################################
+
+train_set = torchvision.datasets.MNIST(args.data_dir + '/mnist/',
+                                       train = True, download = True)
+train_input = train_set.data.view(-1, 1, 28, 28).float()
+
+test_set = torchvision.datasets.MNIST(args.data_dir + '/mnist/',
+                                      train = False, download = True)
+test_input = test_set.data.view(-1, 1, 28, 28).float()
+
+######################################################################
+
+model = AutoEncoder(args.nb_channels, args.embedding_dim)
+optimizer = optim.Adam(model.parameters(), lr = 1e-3)
+
+model.to(device)
+
+train_input, test_input = train_input.to(device), test_input.to(device)
+
+mu, std = train_input.mean(), train_input.std()
+train_input.sub_(mu).div_(std)
+test_input.sub_(mu).div_(std)
+
+######################################################################
+
+for epoch in range(args.nb_epochs):
+
+    acc_loss = 0
+
+    for input in train_input.split(args.batch_size):
+        output = model(input)
+        loss = 0.5 * (output - input).pow(2).sum() / input.size(0)
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        acc_loss += loss.item()
+
+    log_string('acc_loss {:d} {:f}.'.format(epoch, acc_loss))
+
+######################################################################
+
+input = test_input[:256]
+
+# Encode / decode
+
+z = model.encode(input)
+output = model.decode(z)
+
+torchvision.utils.save_image(1 - input, 'ae-input.png', nrow = 16, pad_value = 0.8)
+torchvision.utils.save_image(1 - output, 'ae-output.png', nrow = 16, pad_value = 0.8)
+
+# Dumb synthesis
+
+z = model.encode(input)
+mu, std = z.mean(0), z.std(0)
+z = z.normal_() * std + mu
+output = model.decode(z)
+
+torchvision.utils.save_image(1 - output, 'ae-synth.png', nrow = 16, pad_value = 0.8)
+
+######################################################################