Update.
[pytorch.git] / tinyae.py
index c608c9c..b4f3aba 100755 (executable)
--- a/tinyae.py
+++ b/tinyae.py
@@ -14,77 +14,75 @@ from torch.nn import functional as F
 
 ######################################################################
 
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 ######################################################################
 
-parser = argparse.ArgumentParser(description = 'Tiny LeNet-like auto-encoder.')
+parser = argparse.ArgumentParser(description="Tiny LeNet-like auto-encoder.")
 
-parser.add_argument('--nb_epochs',
-                    type = int, default = 25)
+parser.add_argument("--nb_epochs", type=int, default=25)
 
-parser.add_argument('--batch_size',
-                    type = int, default = 100)
+parser.add_argument("--batch_size", type=int, default=100)
 
-parser.add_argument('--data_dir',
-                    type = str, default = './data/')
+parser.add_argument("--data_dir", type=str, default="./data/")
 
-parser.add_argument('--log_filename',
-                    type = str, default = 'train.log')
+parser.add_argument("--log_filename", type=str, default="train.log")
 
-parser.add_argument('--embedding_dim',
-                    type = int, default = 8)
+parser.add_argument("--embedding_dim", type=int, default=8)
 
-parser.add_argument('--nb_channels',
-                    type = int, default = 32)
-
-parser.add_argument('--force_train',
-                    type = bool, default = False)
+parser.add_argument("--nb_channels", type=int, default=32)
 
 args = parser.parse_args()
 
-log_file = open(args.log_filename, 'w')
+log_file = open(args.log_filename, "w")
 
 ######################################################################
 
+
 def log_string(s):
     t = time.strftime("%Y-%m-%d_%H:%M:%S - ", time.localtime())
 
     if log_file is not None:
-        log_file.write(t + s + '\n')
+        log_file.write(t + s + "\n")
         log_file.flush()
 
     print(t + s)
     sys.stdout.flush()
 
+
 ######################################################################
 
+
 class AutoEncoder(nn.Module):
     def __init__(self, nb_channels, embedding_dim):
-        super(AutoEncoder, self).__init__()
+        super().__init__()
 
         self.encoder = nn.Sequential(
-            nn.Conv2d(1, nb_channels, kernel_size = 5), # to 24x24
-            nn.ReLU(inplace = True),
-            nn.Conv2d(nb_channels, nb_channels, kernel_size = 5), # to 20x20
-            nn.ReLU(inplace = True),
-            nn.Conv2d(nb_channels, nb_channels, kernel_size = 4, stride = 2), # to 9x9
-            nn.ReLU(inplace = True),
-            nn.Conv2d(nb_channels, nb_channels, kernel_size = 3, stride = 2), # to 4x4
-            nn.ReLU(inplace = True),
-            nn.Conv2d(nb_channels, embedding_dim, kernel_size = 4)
+            nn.Conv2d(1, nb_channels, kernel_size=5),  # to 24x24
+            nn.ReLU(inplace=True),
+            nn.Conv2d(nb_channels, nb_channels, kernel_size=5),  # to 20x20
+            nn.ReLU(inplace=True),
+            nn.Conv2d(nb_channels, nb_channels, kernel_size=4, stride=2),  # to 9x9
+            nn.ReLU(inplace=True),
+            nn.Conv2d(nb_channels, nb_channels, kernel_size=3, stride=2),  # to 4x4
+            nn.ReLU(inplace=True),
+            nn.Conv2d(nb_channels, embedding_dim, kernel_size=4),
         )
 
         self.decoder = nn.Sequential(
-            nn.ConvTranspose2d(embedding_dim, nb_channels, kernel_size = 4),
-            nn.ReLU(inplace = True),
-            nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size = 3, stride = 2), # from 4x4
-            nn.ReLU(inplace = True),
-            nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size = 4, stride = 2), # from 9x9
-            nn.ReLU(inplace = True),
-            nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size = 5), # from 20x20
-            nn.ReLU(inplace = True),
-            nn.ConvTranspose2d(nb_channels, 1, kernel_size = 5), # from 24x24
+            nn.ConvTranspose2d(embedding_dim, nb_channels, kernel_size=4),
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(
+                nb_channels, nb_channels, kernel_size=3, stride=2
+            ),  # from 4x4
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(
+                nb_channels, nb_channels, kernel_size=4, stride=2
+            ),  # from 9x9
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size=5),  # from 20x20
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(nb_channels, 1, kernel_size=5),  # from 24x24
         )
 
     def encode(self, x):
@@ -98,20 +96,23 @@ class AutoEncoder(nn.Module):
         x = self.decoder(x)
         return x
 
+
 ######################################################################
 
-train_set = torchvision.datasets.MNIST(args.data_dir + '/mnist/',
-                                       train = True, download = True)
+train_set = torchvision.datasets.MNIST(
+    args.data_dir + "/mnist/", train=True, download=True
+)
 train_input = train_set.data.view(-1, 1, 28, 28).float()
 
-test_set = torchvision.datasets.MNIST(args.data_dir + '/mnist/',
-                                      train = False, download = True)
+test_set = torchvision.datasets.MNIST(
+    args.data_dir + "/mnist/", train=False, download=True
+)
 test_input = test_set.data.view(-1, 1, 28, 28).float()
 
 ######################################################################
 
 model = AutoEncoder(args.nb_channels, args.embedding_dim)
-optimizer = optim.Adam(model.parameters(), lr = 1e-3)
+optimizer = optim.Adam(model.parameters(), lr=1e-3)
 
 model.to(device)
 
@@ -124,7 +125,6 @@ test_input.sub_(mu).div_(std)
 ######################################################################
 
 for epoch in range(args.nb_epochs):
-
     acc_loss = 0
 
     for input in train_input.split(args.batch_size):
@@ -137,7 +137,7 @@ for epoch in range(args.nb_epochs):
 
         acc_loss += loss.item()
 
-    log_string('acc_loss {:d} {:f}.'.format(epoch, acc_loss))
+    log_string("acc_loss {:d} {:f}.".format(epoch, acc_loss))
 
 ######################################################################
 
@@ -148,8 +148,8 @@ input = test_input[:256]
 z = model.encode(input)
 output = model.decode(z)
 
-torchvision.utils.save_image(1 - input, 'ae-input.png', nrow = 16, pad_value = 0.8)
-torchvision.utils.save_image(1 - output, 'ae-output.png', nrow = 16, pad_value = 0.8)
+torchvision.utils.save_image(1 - input, "ae-input.png", nrow=16, pad_value=0.8)
+torchvision.utils.save_image(1 - output, "ae-output.png", nrow=16, pad_value=0.8)
 
 # Dumb synthesis
 
@@ -158,6 +158,6 @@ mu, std = z.mean(0), z.std(0)
 z = z.normal_() * std + mu
 output = model.decode(z)
 
-torchvision.utils.save_image(1 - output, 'ae-synth.png', nrow = 16, pad_value = 0.8)
+torchvision.utils.save_image(1 - output, "ae-synth.png", nrow=16, pad_value=0.8)
 
 ######################################################################