Initial commit.
authorFrancois Fleuret <francois@fleuret.org>
Wed, 18 Dec 2019 15:51:34 +0000 (16:51 +0100)
committerFrancois Fleuret <francois@fleuret.org>
Wed, 18 Dec 2019 15:51:34 +0000 (16:51 +0100)
denoising-ae-field.py [new file with mode: 0755]

diff --git a/denoising-ae-field.py b/denoising-ae-field.py
new file mode 100755 (executable)
index 0000000..175f344
--- /dev/null
@@ -0,0 +1,92 @@
+#!/usr/bin/env python
+
+import math
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+model = nn.Sequential(
+    nn.Linear(2, 100),
+    nn.ReLU(),
+    nn.Linear(100, 2)
+)
+
+############################################################
+
+def data_zigzag(nb):
+    a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
+    # zigzag
+    x = 0.4 * ((a-0.5) * 5 * math.pi).cos()
+    y = a * 2.5 - 1.25
+    data = torch.cat((y, x), 1)
+    data = data @ torch.tensor([[1., -1.], [1., 1.]])
+    return data
+
+def data_spiral(nb):
+    a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
+    x = (a * 2.25 * math.pi).cos() * (a * 0.8 + 0.5)
+    y = (a * 2.25 * math.pi).sin() * (a * 0.8 + 0.5)
+    data = torch.cat((y, x), 1)
+    return data
+
+######################################################################
+
+data = data_spiral(1000)
+# data = data_zigzag(1000)
+
+data = data - data.mean(0)
+
+batch_size, nb_epochs = 100, 1000
+optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
+criterion = nn.MSELoss()
+
+for e in range(nb_epochs):
+    acc_loss = 0
+    for input in data.split(batch_size):
+        noise = input.new(input.size()).normal_(0, 0.1)
+        output = model(input + noise)
+        loss = criterion(output, input)
+        acc_loss += loss.item()
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+    if (e+1)%10 == 0: print(e+1, acc_loss)
+
+######################################################################
+
+a = torch.linspace(-1.5, 1.5, 30)
+x = a.view( 1, -1, 1).expand(a.size(0), a.size(0), 1)
+y = a.view(-1,  1, 1).expand(a.size(0), a.size(0), 1)
+grid = torch.cat((y, x), 2).view(-1, 2)
+
+# Take the origins of the arrows on the part of grid closer than 0.1
+# from the data points
+dist = (grid.view(-1, 1, 2) - data.view(1, -1, 2)).pow(2).sum(2).min(1)[0]
+origins = grid[torch.arange(grid.size(0)).masked_select(dist < 0.1)]
+
+field = model(origins).detach() - origins
+
+######################################################################
+
+import matplotlib.pyplot as plt
+
+fig = plt.figure()
+ax = fig.add_subplot(1, 1, 1)
+
+ax.axis('off')
+ax.set_xlim(-1.6, 1.6)
+ax.set_ylim(-1.6, 1.6)
+ax.set_aspect(1)
+
+plot_field = ax.quiver(origins[:, 0].numpy(), origins[:, 1].numpy(),
+                       field[:, 0].numpy(), field[:, 1].numpy(),
+                       units = 'xy', scale = 1,
+                       width = 3e-3, headwidth = 25, headlength = 25)
+
+plot_data = ax.scatter(data[:, 0].numpy(), data[:, 1].numpy(), s = 1, color = 'tab:blue')
+
+fig.savefig('denoising_field.pdf', bbox_inches='tight')
+
+######################################################################