Update.
[pytorch.git] / denoising-ae-field.py
index 67d2415..3ef0c80 100755 (executable)
 #!/usr/bin/env python
 
-import math
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
 
-import torch, torchvision
+# Written by Francois Fleuret <francois@fleuret.org>
 
+import math
+import matplotlib.pyplot as plt
+
+import torch
 from torch import nn
-from torch.nn import functional as F
 
-model = nn.Sequential(
-    nn.Linear(2, 100),
-    nn.ReLU(),
-    nn.Linear(100, 2)
-)
+######################################################################
+
+
+def data_rectangle(nb):
+    x = torch.rand(nb, 1) - 0.5
+    y = torch.rand(nb, 1) * 2 - 1
+    data = torch.cat((y, x), 1)
+    alpha = math.pi / 8
+    data = data @ torch.tensor(
+        [[math.cos(alpha), math.sin(alpha)], [-math.sin(alpha), math.cos(alpha)]]
+    )
+    return data, "rectangle"
 
-############################################################
 
 def data_zigzag(nb):
     a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
     # zigzag
-    x = 0.4 * ((a-0.5) * 5 * math.pi).cos()
+    x = 0.4 * ((a - 0.5) * 5 * math.pi).cos()
     y = a * 2.5 - 1.25
     data = torch.cat((y, x), 1)
-    data = data @ torch.tensor([[1., -1.], [1., 1.]])
-    return data
+    data = data @ torch.tensor([[1.0, -1.0], [1.0, 1.0]])
+    return data, "zigzag"
+
 
 def data_spiral(nb):
     a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
     x = (a * 2.25 * math.pi).cos() * (a * 0.8 + 0.5)
     y = (a * 2.25 * math.pi).sin() * (a * 0.8 + 0.5)
     data = torch.cat((y, x), 1)
-    return data
+    return data, "spiral"
+
+
+def data_penta(nb):
+    a = (torch.randint(5, (nb,)).float() / 5 * 2 * math.pi).view(-1, 1)
+    x = a.cos()
+    y = a.sin()
+    data = torch.cat((y, x), 1)
+    data = data + data.new(data.size()).normal_(0, 0.05)
+    return data, "penta"
+
 
 ######################################################################
 
-# data = data_spiral(1000)
-data = data_zigzag(1000)
 
-data = data - data.mean(0)
+def train_model(data):
+    model = nn.Sequential(nn.Linear(2, 100), nn.ReLU(), nn.Linear(100, 2))
+
+    batch_size, nb_epochs = 100, 1000
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+    criterion = nn.MSELoss()
 
-batch_size, nb_epochs = 100, 1000
-optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
-criterion = nn.MSELoss()
+    for e in range(nb_epochs):
+        acc_loss = 0
+        for input in data.split(batch_size):
+            noise = input.new(input.size()).normal_(0, 0.1)
+            output = model(input + noise)
+            loss = criterion(output, input)
+            acc_loss += loss.item()
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+        if (e + 1) % 100 == 0:
+            print(e + 1, acc_loss)
+
+    return model
 
-for e in range(nb_epochs):
-    acc_loss = 0
-    for input in data.split(batch_size):
-        noise = input.new(input.size()).normal_(0, 0.1)
-        output = model(input + noise)
-        loss = criterion(output, input)
-        acc_loss += loss.item()
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
-    if (e+1)%10 == 0: print(e+1, acc_loss)
 
 ######################################################################
 
-a = torch.linspace(-1.5, 1.5, 30)
-x = a.view( 1, -1, 1).expand(a.size(0), a.size(0), 1)
-y = a.view(-1,  1, 1).expand(a.size(0), a.size(0), 1)
-grid = torch.cat((y, x), 2).view(-1, 2)
 
-# Take the origins of the arrows on the part of grid closer than 0.1
-# from the data points
-dist = (grid.view(-1, 1, 2) - data.view(1, -1, 2)).pow(2).sum(2).min(1)[0]
-origins = grid[torch.arange(grid.size(0)).masked_select(dist < 0.1)]
+def save_image(data_name, model, data):
+    a = torch.linspace(-1.5, 1.5, 30)
+    x = a.view(1, -1, 1).expand(a.size(0), a.size(0), 1)
+    y = a.view(-1, 1, 1).expand(a.size(0), a.size(0), 1)
+    grid = torch.cat((y, x), 2).view(-1, 2)
 
-field = model(origins).detach() - origins
+    # Take the origins of the arrows on the part of the grid closer than
+    # sqrt(0.1) to the data points
+    dist = (grid.view(-1, 1, 2) - data.view(1, -1, 2)).pow(2).sum(2).min(1)[0]
+    origins = grid[torch.arange(grid.size(0)).masked_select(dist < 0.1)]
 
-######################################################################
+    field = model(origins).detach() - origins
 
-import matplotlib.pyplot as plt
+    fig = plt.figure()
+    ax = fig.add_subplot(1, 1, 1)
 
-fig = plt.figure()
-ax = fig.add_subplot(1, 1, 1)
+    ax.axis("off")
+    ax.set_xlim(-1.6, 1.6)
+    ax.set_ylim(-1.6, 1.6)
+    ax.set_aspect(1)
 
-ax.axis('off')
-ax.set_xlim(-1.6, 1.6)
-ax.set_ylim(-1.6, 1.6)
-ax.set_aspect(1)
+    plot_field = ax.quiver(
+        origins[:, 0].numpy(),
+        origins[:, 1].numpy(),
+        field[:, 0].numpy(),
+        field[:, 1].numpy(),
+        units="xy",
+        scale=1,
+        width=3e-3,
+        headwidth=25,
+        headlength=25,
+    )
 
-plot_field = ax.quiver(origins[:, 0].numpy(), origins[:, 1].numpy(),
-                       field[:, 0].numpy(), field[:, 1].numpy(),
-                       units = 'xy', scale = 1,
-                       width = 3e-3, headwidth = 25, headlength = 25)
+    plot_data = ax.scatter(
+        data[:, 0].numpy(), data[:, 1].numpy(), s=1, color="tab:blue"
+    )
 
-plot_data = ax.scatter(data[:, 0].numpy(), data[:, 1].numpy(), s = 1, color = 'tab:blue')
+    filename = f"denoising_field_{data_name}.pdf"
+    print(f"Saving {filename}")
+    fig.savefig(filename, bbox_inches="tight")
 
-fig.savefig('denoising_field.pdf', bbox_inches='tight')
 
 ######################################################################
+
+for data_source in [data_rectangle, data_zigzag, data_spiral, data_penta]:
+    data, data_name = data_source(1000)
+    data = data - data.mean(0)
+    model = train_model(data)
+    save_image(data_name, model, data)