Update.
[pytorch.git] / minidiffusion.py
index 2c54d19..65ca947 100755 (executable)
@@ -14,14 +14,20 @@ from torch import nn
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
+print(f'device {device}')
+
 ######################################################################
 
 def sample_gaussian_mixture(nb):
     p, std = 0.3, 0.2
-    result = torch.empty(nb, 1).normal_(0, std)
+    result = torch.randn(nb, 1) * std
     result = result + torch.sign(torch.rand(result.size()) - p) / 2
     return result
 
+def sample_ramp(nb):
+    result = torch.min(torch.rand(nb, 1), torch.rand(nb, 1))
+    return result
+
 def sample_two_discs(nb):
     a = torch.rand(nb) * math.pi * 2
     b = torch.rand(nb).sqrt()
@@ -35,8 +41,9 @@ def sample_two_discs(nb):
 def sample_disc_grid(nb):
     a = torch.rand(nb) * math.pi * 2
     b = torch.rand(nb).sqrt()
-    q = torch.randint(5, (nb,)) / 2.5 - 2 / 2.5
-    r = torch.randint(5, (nb,)) / 2.5 - 2 / 2.5
+    N = 4
+    q = (torch.randint(N, (nb,)) - (N - 1) / 2) / ((N - 1) / 2)
+    r = (torch.randint(N, (nb,)) - (N - 1) / 2) / ((N - 1) / 2)
     b = b * 0.1
     result = torch.empty(nb, 2)
     result[:, 0] = a.cos() * b + q
@@ -59,6 +66,7 @@ def sample_mnist(nb):
 
 samplers = {
     'gaussian_mixture': sample_gaussian_mixture,
+    'ramp': sample_ramp,
     'two_discs': sample_two_discs,
     'disc_grid': sample_disc_grid,
     'spiral': sample_spiral,
@@ -97,7 +105,7 @@ parser.add_argument('--learning_rate',
 
 parser.add_argument('--ema_decay',
                     type = float, default = 0.9999,
-                    help = 'EMA decay, < 0 is no EMA')
+                    help = 'EMA decay, <= 0 is no EMA')
 
 data_list = ', '.join( [ str(k) for k in samplers ])
 
@@ -121,23 +129,20 @@ class EMA:
     def __init__(self, model, decay):
         self.model = model
         self.decay = decay
-        if self.decay < 0: return
-        self.ema = { }
+        self.mem = { }
         with torch.no_grad():
             for p in model.parameters():
-                self.ema[p] = p.clone()
+                self.mem[p] = p.clone()
 
     def step(self):
-        if self.decay < 0: return
         with torch.no_grad():
             for p in self.model.parameters():
-                self.ema[p].copy_(self.decay * self.ema[p] + (1 - self.decay) * p)
+                self.mem[p].copy_(self.decay * self.mem[p] + (1 - self.decay) * p)
 
-    def copy(self):
-        if self.decay < 0: return
+    def copy_to_model(self):
         with torch.no_grad():
             for p in self.model.parameters():
-                p.copy_(self.ema[p])
+                p.copy_(self.mem[p])
 
 ######################################################################
 
@@ -179,7 +184,7 @@ train_mean, train_std = train_input.mean(), train_input.std()
 # Model
 
 if train_input.dim() == 2:
-    nh = 64
+    nh = 256
 
     model = nn.Sequential(
         nn.Linear(train_input.size(1) + 1, nh),
@@ -197,6 +202,28 @@ elif train_input.dim() == 4:
 
 model.to(device)
 
+print(f'nb_parameters {sum([ p.numel() for p in model.parameters() ])}')
+
+######################################################################
+# Generate
+
+def generate(size, alpha, alpha_bar, sigma, model, train_mean, train_std):
+
+    with torch.no_grad():
+
+        x = torch.randn(size, device = device)
+
+        for t in range(T-1, -1, -1):
+            z = torch.zeros_like(x) if t == 0 else torch.randn_like(x)
+            input = torch.cat((x, torch.full_like(x[:,:1], t / (T - 1) - 0.5)), 1)
+            x = 1/torch.sqrt(alpha[t]) \
+                * (x - (1-alpha[t]) / torch.sqrt(1-alpha_bar[t]) * model(input)) \
+                + sigma[t] * z
+
+        x = x * train_std + train_mean
+
+        return x
+
 ######################################################################
 # Train
 
@@ -206,7 +233,7 @@ alpha = 1 - beta
 alpha_bar = alpha.log().cumsum(0).exp()
 sigma = beta.sqrt()
 
-ema = EMA(model, decay = args.ema_decay)
+ema = EMA(model, decay = args.ema_decay) if args.ema_decay > 0 else None
 
 for k in range(args.nb_epochs):
 
@@ -217,8 +244,8 @@ for k in range(args.nb_epochs):
         x0 = (x0 - train_mean) / train_std
         t = torch.randint(T, (x0.size(0),) + (1,) * (x0.dim() - 1), device = x0.device)
         eps = torch.randn_like(x0)
-        input = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * eps
-        input = torch.cat((input, t.expand_as(x0[:,:1]) / (T - 1) - 0.5), 1)
+        xt = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * eps
+        input = torch.cat((xt, t.expand_as(x0[:,:1]) / (T - 1) - 0.5), 1)
         loss = (eps - model(input)).pow(2).mean()
         acc_loss += loss.item() * x0.size(0)
 
@@ -226,29 +253,11 @@ for k in range(args.nb_epochs):
         loss.backward()
         optimizer.step()
 
-        ema.step()
+        if ema is not None: ema.step()
 
-    if k%10 == 0: print(f'{k} {acc_loss / train_input.size(0)}')
+    print(f'{k} {acc_loss / train_input.size(0)}')
 
-ema.copy()
-
-######################################################################
-# Generate
-
-def generate(size, model):
-    with torch.no_grad():
-        x = torch.randn(size, device = device)
-
-        for t in range(T-1, -1, -1):
-            z = torch.zeros_like(x) if t == 0 else torch.randn_like(x)
-            input = torch.cat((x, torch.full_like(x[:,:1], t / (T - 1) - 0.5)), 1)
-            x = 1/torch.sqrt(alpha[t]) \
-                * (x - (1-alpha[t]) / torch.sqrt(1-alpha_bar[t]) * model(input)) \
-                + sigma[t] * z
-
-        x = x * train_std + train_mean
-
-        return x
+if ema is not None: ema.copy_to_model()
 
 ######################################################################
 # Plot
@@ -256,14 +265,19 @@ def generate(size, model):
 model.eval()
 
 if train_input.dim() == 2:
+
     fig = plt.figure()
     ax = fig.add_subplot(1, 1, 1)
 
+    # Nx1 -> histogram
     if train_input.size(1) == 1:
 
-        x = generate((10000, 1), model)
+        x = generate((10000, 1), alpha, alpha_bar, sigma,
+                     model, train_mean, train_std)
 
         ax.set_xlim(-1.25, 1.25)
+        ax.spines.right.set_visible(False)
+        ax.spines.top.set_visible(False)
 
         d = train_input.flatten().detach().to('cpu').numpy()
         ax.hist(d, 25, (-1, 1),
@@ -277,21 +291,25 @@ if train_input.dim() == 2:
 
         ax.legend(frameon = False, loc = 2)
 
+    # Nx2 -> scatter plot
     elif train_input.size(1) == 2:
 
-        x = generate((1000, 2), model)
+        x = generate((1000, 2), alpha, alpha_bar, sigma,
+                     model, train_mean, train_std)
 
-        ax.set_xlim(-1.25, 1.25)
-        ax.set_ylim(-1.25, 1.25)
+        ax.set_xlim(-1.5, 1.5)
+        ax.set_ylim(-1.5, 1.5)
         ax.set(aspect = 1)
+        ax.spines.right.set_visible(False)
+        ax.spines.top.set_visible(False)
 
-        d = train_input[:x.size(0)].detach().to('cpu').numpy()
+        d = x.detach().to('cpu').numpy()
         ax.scatter(d[:, 0], d[:, 1],
-                   color = 'lightblue', label = 'Train')
+                   s = 2.0, color = 'red', label = 'Synthesis')
 
-        d = x.detach().to('cpu').numpy()
+        d = train_input[:x.size(0)].detach().to('cpu').numpy()
         ax.scatter(d[:, 0], d[:, 1],
-                   facecolors = 'none', color = 'red', label = 'Synthesis')
+                   s = 2.0, color = 'gray', label = 'Train')
 
         ax.legend(frameon = False, loc = 2)
 
@@ -303,8 +321,11 @@ if train_input.dim() == 2:
         plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
         plt.show()
 
+# NxCxHxW -> image
 elif train_input.dim() == 4:
-    x = generate((128,) + train_input.size()[1:], model)
+
+    x = generate((128,) + train_input.size()[1:], alpha, alpha_bar, sigma,
+                 model, train_mean, train_std)
     x = 1 - x.clamp(min = 0, max = 255) / 255
     torchvision.utils.save_image(x, f'diffusion_{args.data}.png', nrow = 16, pad_value = 0.8)