Initial commit.
[pytorch.git] / tinyae.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import sys, argparse, time
9
10 import torch, torchvision
11
12 from torch import optim, nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18
19 ######################################################################
20
21 parser = argparse.ArgumentParser(description = 'Tiny LeNet-like auto-encoder.')
22
23 parser.add_argument('--nb_epochs',
24                     type = int, default = 25)
25
26 parser.add_argument('--batch_size',
27                     type = int, default = 100)
28
29 parser.add_argument('--data_dir',
30                     type = str, default = './data/')
31
32 parser.add_argument('--log_filename',
33                     type = str, default = 'train.log')
34
35 parser.add_argument('--embedding_dim',
36                     type = int, default = 8)
37
38 parser.add_argument('--nb_channels',
39                     type = int, default = 32)
40
41 parser.add_argument('--force_train',
42                     type = bool, default = False)
43
44 args = parser.parse_args()
45
46 log_file = open(args.log_filename, 'w')
47
48 ######################################################################
49
50 def log_string(s):
51     t = time.strftime("%Y-%m-%d_%H:%M:%S - ", time.localtime())
52
53     if log_file is not None:
54         log_file.write(t + s + '\n')
55         log_file.flush()
56
57     print(t + s)
58     sys.stdout.flush()
59
60 ######################################################################
61
62 class AutoEncoder(nn.Module):
63     def __init__(self, nb_channels, embedding_dim):
64         super(AutoEncoder, self).__init__()
65
66         self.encoder = nn.Sequential(
67             nn.Conv2d(1, nb_channels, kernel_size = 5), # to 24x24
68             nn.ReLU(inplace = True),
69             nn.Conv2d(nb_channels, nb_channels, kernel_size = 5), # to 20x20
70             nn.ReLU(inplace = True),
71             nn.Conv2d(nb_channels, nb_channels, kernel_size = 4, stride = 2), # to 9x9
72             nn.ReLU(inplace = True),
73             nn.Conv2d(nb_channels, nb_channels, kernel_size = 3, stride = 2), # to 4x4
74             nn.ReLU(inplace = True),
75             nn.Conv2d(nb_channels, embedding_dim, kernel_size = 4)
76         )
77
78         self.decoder = nn.Sequential(
79             nn.ConvTranspose2d(embedding_dim, nb_channels, kernel_size = 4),
80             nn.ReLU(inplace = True),
81             nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size = 3, stride = 2), # from 4x4
82             nn.ReLU(inplace = True),
83             nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size = 4, stride = 2), # from 9x9
84             nn.ReLU(inplace = True),
85             nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size = 5), # from 20x20
86             nn.ReLU(inplace = True),
87             nn.ConvTranspose2d(nb_channels, 1, kernel_size = 5), # from 24x24
88         )
89
90     def encode(self, x):
91         return self.encoder(x).view(x.size(0), -1)
92
93     def decode(self, z):
94         return self.decoder(z.view(z.size(0), -1, 1, 1))
95
96     def forward(self, x):
97         x = self.encoder(x)
98         x = self.decoder(x)
99         return x
100
101 ######################################################################
102
103 train_set = torchvision.datasets.MNIST(args.data_dir + '/mnist/',
104                                        train = True, download = True)
105 train_input = train_set.data.view(-1, 1, 28, 28).float()
106
107 test_set = torchvision.datasets.MNIST(args.data_dir + '/mnist/',
108                                       train = False, download = True)
109 test_input = test_set.data.view(-1, 1, 28, 28).float()
110
111 ######################################################################
112
113 model = AutoEncoder(args.nb_channels, args.embedding_dim)
114 optimizer = optim.Adam(model.parameters(), lr = 1e-3)
115
116 model.to(device)
117
118 train_input, test_input = train_input.to(device), test_input.to(device)
119
120 mu, std = train_input.mean(), train_input.std()
121 train_input.sub_(mu).div_(std)
122 test_input.sub_(mu).div_(std)
123
124 ######################################################################
125
126 for epoch in range(args.nb_epochs):
127
128     acc_loss = 0
129
130     for input in train_input.split(args.batch_size):
131         output = model(input)
132         loss = 0.5 * (output - input).pow(2).sum() / input.size(0)
133
134         optimizer.zero_grad()
135         loss.backward()
136         optimizer.step()
137
138         acc_loss += loss.item()
139
140     log_string('acc_loss {:d} {:f}.'.format(epoch, acc_loss))
141
142 ######################################################################
143
144 input = test_input[:256]
145
146 # Encode / decode
147
148 z = model.encode(input)
149 output = model.decode(z)
150
151 torchvision.utils.save_image(1 - input, 'ae-input.png', nrow = 16, pad_value = 0.8)
152 torchvision.utils.save_image(1 - output, 'ae-output.png', nrow = 16, pad_value = 0.8)
153
154 # Dumb synthesis
155
156 z = model.encode(input)
157 mu, std = z.mean(0), z.std(0)
158 z = z.normal_() * std + mu
159 output = model.decode(z)
160
161 torchvision.utils.save_image(1 - output, 'ae-synth.png', nrow = 16, pad_value = 0.8)
162
163 ######################################################################