Update.
authorFrancois Fleuret <francois@fleuret.org>
Sun, 14 Aug 2022 13:59:51 +0000 (15:59 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Sun, 14 Aug 2022 13:59:51 +0000 (15:59 +0200)
minidiffusion.py

index e1f6abd..841dd2a 100755 (executable)
@@ -207,8 +207,10 @@ print(f'nb_parameters {sum([ p.numel() for p in model.parameters() ])}')
 ######################################################################
 # Generate
 
-def generate(size, alpha, alpha_bar, sigma, model):
+def generate(size, alpha, alpha_bar, sigma, model, train_mean, train_std):
+
     with torch.no_grad():
+
         x = torch.randn(size, device = device)
 
         for t in range(T-1, -1, -1):
@@ -269,7 +271,8 @@ if train_input.dim() == 2:
 
     if train_input.size(1) == 1:
 
-        x = generate((10000, 1), alpha, alpha_bar, sigma, model)
+        x = generate((10000, 1), alpha, alpha_bar, sigma,
+                     model, train_mean, train_std)
 
         ax.set_xlim(-1.25, 1.25)
         ax.spines.right.set_visible(False)
@@ -289,7 +292,8 @@ if train_input.dim() == 2:
 
     elif train_input.size(1) == 2:
 
-        x = generate((1000, 2), alpha, alpha_bar, sigma, model)
+        x = generate((1000, 2), alpha, alpha_bar, sigma,
+                     model, train_mean, train_std)
 
         ax.set_xlim(-1.5, 1.5)
         ax.set_ylim(-1.5, 1.5)
@@ -317,7 +321,8 @@ if train_input.dim() == 2:
 
 elif train_input.dim() == 4:
 
-    x = generate((128,) + train_input.size()[1:], alpha, alpha_bar, sigma, model)
+    x = generate((128,) + train_input.size()[1:], alpha, alpha_bar, sigma,
+                 model, train_mean, train_std)
     x = 1 - x.clamp(min = 0, max = 255) / 255
     torchvision.utils.save_image(x, f'diffusion_{args.data}.png', nrow = 16, pad_value = 0.8)