Update.
authorFrancois Fleuret <francois@fleuret.org>
Fri, 12 Aug 2022 07:57:09 +0000 (09:57 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Fri, 12 Aug 2022 07:57:09 +0000 (09:57 +0200)
minidiffusion.py

index ad1cda0..037ef11 100755 (executable)
@@ -5,6 +5,11 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
+# Minimal implementation of Jonathan Ho, Ajay Jain, Pieter Abbeel
+# "Denoising Diffusion Probabilistic Models" (2020)
+#
+# https://arxiv.org/abs/2006.11239
+
 import matplotlib.pyplot as plt
 import torch
 from torch import nn
@@ -62,7 +67,7 @@ for k in range(nb_epochs):
     if k%10 == 0: print(k, loss.item())
 
 ######################################################################
-# Plot
+# Generate
 
 x = torch.randn(10000, 1)
 
@@ -71,19 +76,27 @@ for t in range(T-1, -1, -1):
     input = torch.cat((x, torch.ones(x.size(0), 1) * 2 * t / T - 1), 1)
     x = 1 / alpha[t].sqrt() * (x - (1 - alpha[t])/(1 - alpha_bar[t]).sqrt() * model(input)) + sigma[t] * z
 
+######################################################################
+# Plot
+
 fig = plt.figure()
 ax = fig.add_subplot(1, 1, 1)
 ax.set_xlim(-1.25, 1.25)
 
 d = train_input.flatten().detach().numpy()
-ax.hist(d, 25, (-1, 1), histtype = 'stepfilled', color = 'lightblue', density = True, label = 'Train')
+ax.hist(d, 25, (-1, 1),
+        density = True,
+        histtype = 'stepfilled', color = 'lightblue', label = 'Train')
 
 d = x.flatten().detach().numpy()
-ax.hist(d, 25, (-1, 1), histtype = 'step', color = 'red', density = True, label = 'Synthesis')
+ax.hist(d, 25, (-1, 1),
+        density = True,
+        histtype = 'step', color = 'red', label = 'Synthesis')
 
 ax.legend(frameon = False, loc = 2)
 
 filename = 'diffusion.pdf'
+print(f'saving {filename}')
 fig.savefig(filename, bbox_inches='tight')
 
 plt.show()