066cbbbe1fa365458ea163bea0e64e1cd2787c74
[pytorch.git] / minidiffusion.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, argparse
9
10 import matplotlib.pyplot as plt
11
12 import torch, torchvision
13 from torch import nn
14 from torch.nn import functional as F
15
16 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
18 print(f'device {device}')
19
20 ######################################################################
21
22 def sample_gaussian_mixture(nb):
23     p, std = 0.3, 0.2
24     result = torch.randn(nb, 1) * std
25     result = result + torch.sign(torch.rand(result.size()) - p) / 2
26     return result
27
28 def sample_ramp(nb):
29     result = torch.min(torch.rand(nb, 1), torch.rand(nb, 1))
30     return result
31
32 def sample_two_discs(nb):
33     a = torch.rand(nb) * math.pi * 2
34     b = torch.rand(nb).sqrt()
35     q = (torch.rand(nb) <= 0.5).long()
36     b = b * (0.3 + 0.2 * q)
37     result = torch.empty(nb, 2)
38     result[:, 0] = a.cos() * b - 0.5 + q
39     result[:, 1] = a.sin() * b - 0.5 + q
40     return result
41
42 def sample_disc_grid(nb):
43     a = torch.rand(nb) * math.pi * 2
44     b = torch.rand(nb).sqrt()
45     N = 4
46     q = (torch.randint(N, (nb,)) - (N - 1) / 2) / ((N - 1) / 2)
47     r = (torch.randint(N, (nb,)) - (N - 1) / 2) / ((N - 1) / 2)
48     b = b * 0.1
49     result = torch.empty(nb, 2)
50     result[:, 0] = a.cos() * b + q
51     result[:, 1] = a.sin() * b + r
52     return result
53
54 def sample_spiral(nb):
55     u = torch.rand(nb)
56     rho = u * 0.65 + 0.25 + torch.rand(nb) * 0.15
57     theta = u * math.pi * 3
58     result = torch.empty(nb, 2)
59     result[:, 0] = theta.cos() * rho
60     result[:, 1] = theta.sin() * rho
61     return result
62
63 def sample_mnist(nb):
64     train_set = torchvision.datasets.MNIST(root = './data/', train = True, download = True)
65     result = train_set.data[:nb].to(device).view(-1, 1, 28, 28).float()
66     return result
67
68 samplers = {
69     f.__name__.removeprefix('sample_') : f for f in [
70         sample_gaussian_mixture,
71         sample_ramp,
72         sample_two_discs,
73         sample_disc_grid,
74         sample_spiral,
75         sample_mnist,
76     ]
77 }
78
79 ######################################################################
80
81 parser = argparse.ArgumentParser(
82     description = '''A minimal implementation of Jonathan Ho, Ajay Jain, Pieter Abbeel
83 "Denoising Diffusion Probabilistic Models" (2020)
84 https://arxiv.org/abs/2006.11239''',
85
86     formatter_class = argparse.ArgumentDefaultsHelpFormatter
87 )
88
89 parser.add_argument('--seed',
90                     type = int, default = 0,
91                     help = 'Random seed, < 0 is no seeding')
92
93 parser.add_argument('--nb_epochs',
94                     type = int, default = 100,
95                     help = 'How many epochs')
96
97 parser.add_argument('--batch_size',
98                     type = int, default = 25,
99                     help = 'Batch size')
100
101 parser.add_argument('--nb_samples',
102                     type = int, default = 25000,
103                     help = 'Number of training examples')
104
105 parser.add_argument('--learning_rate',
106                     type = float, default = 1e-3,
107                     help = 'Learning rate')
108
109 parser.add_argument('--ema_decay',
110                     type = float, default = 0.9999,
111                     help = 'EMA decay, <= 0 is no EMA')
112
113 data_list = ', '.join( [ str(k) for k in samplers ])
114
115 parser.add_argument('--data',
116                     type = str, default = 'gaussian_mixture',
117                     help = f'Toy data-set to use: {data_list}')
118
119 parser.add_argument('--no_window',
120                     action='store_true', default = False)
121
122 args = parser.parse_args()
123
124 if args.seed >= 0:
125     # torch.backends.cudnn.deterministic = True
126     # torch.backends.cudnn.benchmark = False
127     # torch.use_deterministic_algorithms(True)
128     torch.manual_seed(args.seed)
129     if torch.cuda.is_available():
130         torch.cuda.manual_seed_all(args.seed)
131
132 ######################################################################
133
134 class EMA:
135     def __init__(self, model, decay):
136         self.model = model
137         self.decay = decay
138         self.mem = { }
139         with torch.no_grad():
140             for p in model.parameters():
141                 self.mem[p] = p.clone()
142
143     def step(self):
144         with torch.no_grad():
145             for p in self.model.parameters():
146                 self.mem[p].copy_(self.decay * self.mem[p] + (1 - self.decay) * p)
147
148     def copy_to_model(self):
149         with torch.no_grad():
150             for p in self.model.parameters():
151                 p.copy_(self.mem[p])
152
153 ######################################################################
154
155 # Gets a pair (x, t) and appends t (scalar or 1d tensor) to x as an
156 # additional dimension / channel
157
158 class TimeAppender(nn.Module):
159     def __init__(self):
160         super().__init__()
161
162     def forward(self, u):
163         x, t = u
164         if not torch.is_tensor(t):
165             t = x.new_full((x.size(0),), t)
166         t = t.view((-1,) + (1,) * (x.dim() - 1)).expand_as(x[:,:1])
167         return torch.cat((x, t), 1)
168
169 class ConvNet(nn.Module):
170     def __init__(self, in_channels, out_channels):
171         super().__init__()
172
173         ks, nc = 5, 64
174
175         self.core = nn.Sequential(
176             TimeAppender(),
177             nn.Conv2d(in_channels + 1, nc, ks, padding = ks//2),
178             nn.ReLU(),
179             nn.Conv2d(nc, nc, ks, padding = ks//2),
180             nn.ReLU(),
181             nn.Conv2d(nc, nc, ks, padding = ks//2),
182             nn.ReLU(),
183             nn.Conv2d(nc, nc, ks, padding = ks//2),
184             nn.ReLU(),
185             nn.Conv2d(nc, nc, ks, padding = ks//2),
186             nn.ReLU(),
187             nn.Conv2d(nc, out_channels, ks, padding = ks//2),
188         )
189
190     def forward(self, u):
191         return self.core(u)
192
193 ######################################################################
194 # Data
195
196 try:
197     train_input = samplers[args.data](args.nb_samples).to(device)
198 except KeyError:
199     print(f'unknown data {args.data}')
200     exit(1)
201
202 train_mean, train_std = train_input.mean(), train_input.std()
203
204 ######################################################################
205 # Model
206
207 if train_input.dim() == 2:
208     nh = 256
209
210     model = nn.Sequential(
211         TimeAppender(),
212         nn.Linear(train_input.size(1) + 1, nh),
213         nn.ReLU(),
214         nn.Linear(nh, nh),
215         nn.ReLU(),
216         nn.Linear(nh, nh),
217         nn.ReLU(),
218         nn.Linear(nh, train_input.size(1)),
219     )
220
221 elif train_input.dim() == 4:
222
223     model = ConvNet(train_input.size(1), train_input.size(1))
224
225 model.to(device)
226
227 print(f'nb_parameters {sum([ p.numel() for p in model.parameters() ])}')
228
229 ######################################################################
230 # Generate
231
232 def generate(size, T, alpha, alpha_bar, sigma, model, train_mean, train_std):
233
234     with torch.no_grad():
235
236         x = torch.randn(size, device = device)
237
238         for t in range(T-1, -1, -1):
239             output = model((x, t / (T - 1) - 0.5))
240             z = torch.zeros_like(x) if t == 0 else torch.randn_like(x)
241             x = 1/torch.sqrt(alpha[t]) \
242                 * (x - (1-alpha[t]) / torch.sqrt(1-alpha_bar[t]) * output) \
243                 + sigma[t] * z
244
245         x = x * train_std + train_mean
246
247         return x
248
249 ######################################################################
250 # Train
251
252 T = 1000
253 beta = torch.linspace(1e-4, 0.02, T, device = device)
254 alpha = 1 - beta
255 alpha_bar = alpha.log().cumsum(0).exp()
256 sigma = beta.sqrt()
257
258 ema = EMA(model, decay = args.ema_decay) if args.ema_decay > 0 else None
259
260 for k in range(args.nb_epochs):
261
262     acc_loss = 0
263     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
264
265     for x0 in train_input.split(args.batch_size):
266         x0 = (x0 - train_mean) / train_std
267         t = torch.randint(T, (x0.size(0),) + (1,) * (x0.dim() - 1), device = x0.device)
268         eps = torch.randn_like(x0)
269         xt = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * eps
270         output = model((xt, t / (T - 1) - 0.5))
271         loss = (eps - output).pow(2).mean()
272         acc_loss += loss.item() * x0.size(0)
273
274         optimizer.zero_grad()
275         loss.backward()
276         optimizer.step()
277
278         if ema is not None: ema.step()
279
280     print(f'{k} {acc_loss / train_input.size(0)}')
281
282 if ema is not None: ema.copy_to_model()
283
284 ######################################################################
285 # Plot
286
287 model.eval()
288
289 ########################################
290 # Nx1 -> histogram
291 if train_input.dim() == 2 and train_input.size(1) == 1:
292
293     fig = plt.figure()
294     fig.set_figheight(5)
295     fig.set_figwidth(8)
296
297     ax = fig.add_subplot(1, 1, 1)
298
299     x = generate((10000, 1), T, alpha, alpha_bar, sigma,
300                  model, train_mean, train_std)
301
302     ax.set_xlim(-1.25, 1.25)
303     ax.spines.right.set_visible(False)
304     ax.spines.top.set_visible(False)
305
306     d = train_input.flatten().detach().to('cpu').numpy()
307     ax.hist(d, 25, (-1, 1),
308             density = True,
309             histtype = 'bar', edgecolor = 'white', color = 'lightblue', label = 'Train')
310
311     d = x.flatten().detach().to('cpu').numpy()
312     ax.hist(d, 25, (-1, 1),
313             density = True,
314             histtype = 'step', color = 'red', label = 'Synthesis')
315
316     ax.legend(frameon = False, loc = 2)
317
318     filename = f'minidiffusion_{args.data}.pdf'
319     print(f'saving {filename}')
320     fig.savefig(filename, bbox_inches='tight')
321
322     if not args.no_window and hasattr(plt.get_current_fig_manager(), 'window'):
323         plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
324         plt.show()
325
326 ########################################
327 # Nx2 -> scatter plot
328 elif train_input.dim() == 2 and train_input.size(1) == 2:
329
330     fig = plt.figure()
331     fig.set_figheight(6)
332     fig.set_figwidth(6)
333
334     ax = fig.add_subplot(1, 1, 1)
335
336     x = generate((1000, 2), T, alpha, alpha_bar, sigma,
337                  model, train_mean, train_std)
338
339     ax.set_xlim(-1.5, 1.5)
340     ax.set_ylim(-1.5, 1.5)
341     ax.set(aspect = 1)
342     ax.spines.right.set_visible(False)
343     ax.spines.top.set_visible(False)
344
345     d = train_input[:x.size(0)].detach().to('cpu').numpy()
346     ax.scatter(d[:, 0], d[:, 1],
347                s = 2.5, color = 'gray', label = 'Train')
348
349     d = x.detach().to('cpu').numpy()
350     ax.scatter(d[:, 0], d[:, 1],
351                s = 2.0, color = 'red', label = 'Synthesis')
352
353     ax.legend(frameon = False, loc = 2)
354
355     filename = f'minidiffusion_{args.data}.pdf'
356     print(f'saving {filename}')
357     fig.savefig(filename, bbox_inches='tight')
358
359     if not args.no_window and hasattr(plt.get_current_fig_manager(), 'window'):
360         plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
361         plt.show()
362
363 ########################################
364 # NxCxHxW -> image
365 elif train_input.dim() == 4:
366
367     x = generate((128,) + train_input.size()[1:], T, alpha, alpha_bar, sigma,
368                  model, train_mean, train_std)
369
370     x = torchvision.utils.make_grid(x.clamp(min = 0, max = 255),
371                                     nrow = 16, padding = 1, pad_value = 64)
372     x = F.pad(x, pad = (2, 2, 2, 2), value = 64)[None]
373
374     t = torchvision.utils.make_grid(train_input[:128],
375                                     nrow = 16, padding = 1, pad_value = 64)
376     t = F.pad(t, pad = (2, 2, 2, 2), value = 64)[None]
377
378     result = 1 - torch.cat((t, x), 2) / 255
379
380     filename = f'minidiffusion_{args.data}.png'
381     print(f'saving {filename}')
382     torchvision.utils.save_image(result, filename)
383
384 else:
385
386     print(f'cannot plot result of size {train_input.size()}')
387
388 ######################################################################