50f9d10b0eaf57f0704bd0f4b5e02fee8a274bf9
[pytorch.git] / autoencoder.py
1 #!/usr/bin/env python
2
3 # @XREMOTE_HOST: elk.fleuret.org
4 # @XREMOTE_EXEC: /home/fleuret/conda/bin/python
5 # @XREMOTE_PRE: killall -q -9 python || true
6 # @XREMOTE_PRE: ln -sf /home/fleuret/data/pytorch ./data
7 # @XREMOTE_GET: *.log *.dat *.png *.pth
8
9 import sys, argparse, os, time
10
11 import torch, torchvision
12
13 from torch import optim, nn
14 from torch.nn import functional as F
15
16 import torchvision
17
18 ######################################################################
19
20 if torch.cuda.is_available():
21     device = torch.device('cuda')
22 else:
23     device = torch.device('cpu')
24
25 ######################################################################
26
27 parser = argparse.ArgumentParser(description = 'Simple auto-encoder.')
28
29 parser.add_argument('--nb_epochs',
30                     type = int, default = 25)
31
32 parser.add_argument('--batch_size',
33                     type = int, default = 100)
34
35 parser.add_argument('--data_dir',
36                     type = str, default = './data/')
37
38 parser.add_argument('--log_filename',
39                     type = str, default = 'train.log')
40
41 parser.add_argument('--embedding_dim',
42                     type = int, default = 16)
43
44 parser.add_argument('--nb_channels',
45                     type = int, default = 32)
46
47 parser.add_argument('--force_train',
48                     type = bool, default = False)
49
50 args = parser.parse_args()
51
52 log_file = open(args.log_filename, 'w')
53
54 ######################################################################
55
56 def log_string(s, color = None):
57     t = time.strftime("%Y-%m-%d_%H:%M:%S - ", time.localtime())
58
59     if log_file is not None:
60         log_file.write(t + s + '\n')
61         log_file.flush()
62
63     print(t + s)
64     sys.stdout.flush()
65
66 ######################################################################
67
68 class AutoEncoder(nn.Module):
69     def __init__(self, nb_channels, embedding_dim):
70         super(AutoEncoder, self).__init__()
71
72         self.encoder = nn.Sequential(
73             nn.Conv2d(1, nb_channels, kernel_size = 5), # to 24x24
74             nn.ReLU(inplace = True),
75             nn.Conv2d(nb_channels, nb_channels, kernel_size = 5), # to 20x20
76             nn.ReLU(inplace = True),
77             nn.Conv2d(nb_channels, nb_channels, kernel_size = 4, stride = 2), # to 9x9
78             nn.ReLU(inplace = True),
79             nn.Conv2d(nb_channels, nb_channels, kernel_size = 3, stride = 2), # to 4x4
80             nn.ReLU(inplace = True),
81             nn.Conv2d(nb_channels, embedding_dim, kernel_size = 4)
82         )
83
84         self.decoder = nn.Sequential(
85             nn.ConvTranspose2d(embedding_dim, nb_channels, kernel_size = 4),
86             nn.ReLU(inplace = True),
87             nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size = 3, stride = 2), # from 4x4
88             nn.ReLU(inplace = True),
89             nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size = 4, stride = 2), # from 9x9
90             nn.ReLU(inplace = True),
91             nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size = 5), # from 20x20
92             nn.ReLU(inplace = True),
93             nn.ConvTranspose2d(nb_channels, 1, kernel_size = 5), # from 24x24
94         )
95
96     def encode(self, x):
97         return self.encoder(x).view(x.size(0), -1)
98
99     def decode(self, z):
100         return self.decoder(z.view(z.size(0), -1, 1, 1))
101
102     def forward(self, x):
103         x = self.encoder(x)
104         # print(x.size())
105         x = self.decoder(x)
106         return x
107
108 ######################################################################
109
110 train_set = torchvision.datasets.MNIST(args.data_dir + '/mnist/',
111                                        train = True, download = True)
112 train_input = train_set.data.view(-1, 1, 28, 28).float()
113
114 test_set = torchvision.datasets.MNIST(args.data_dir + '/mnist/',
115                                       train = False, download = True)
116 test_input = test_set.data.view(-1, 1, 28, 28).float()
117
118 ######################################################################
119
120 train_input, test_input = train_input.to(device), test_input.to(device)
121
122 mu, std = train_input.mean(), train_input.std()
123 train_input.sub_(mu).div_(std)
124 test_input.sub_(mu).div_(std)
125
126 model = AutoEncoder(args.nb_channels, args.embedding_dim)
127 optimizer = optim.Adam(model.parameters(), lr = 1e-3)
128
129 model.to(device)
130
131 for epoch in range(args.nb_epochs):
132     acc_loss = 0
133     for input in train_input.split(args.batch_size):
134         input = input.to(device)
135         z = model.encode(input)
136         output = model.decode(z)
137         loss = 0.5 * (output - input).pow(2).sum() / input.size(0)
138
139         optimizer.zero_grad()
140         loss.backward()
141         optimizer.step()
142
143         acc_loss += loss.item()
144
145     log_string(f'acc_loss {epoch} {acc_loss}', 'blue')
146
147 ######################################################################
148
149 input = test_input[:256]
150 z = model.encode(input)
151 output = model.decode(z)
152
153 torchvision.utils.save_image(1 - input, 'ae-input.png', nrow = 16, pad_value = 0.8)
154 torchvision.utils.save_image(1 - output, 'ae-output.png', nrow = 16, pad_value = 0.8)
155
156 ######################################################################
157
158 input = train_input[:256]
159 z = model.encode(input)
160 mu, std = z.mean(0), z.std(0)
161 z = z.normal_() * std + mu
162 output = model.decode(z)
163 torchvision.utils.save_image(1 - output, 'ae-synth.png', nrow = 16, pad_value = 0.8)
164
165 ######################################################################