Update.
[pytorch.git] / denoising-ae-field.py
index 2aa3648..3ef0c80 100755 (executable)
@@ -1,5 +1,10 @@
 #!/usr/bin/env python
 
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
 import math
 import matplotlib.pyplot as plt
 
@@ -8,21 +13,35 @@ from torch import nn
 
 ######################################################################
 
+
+def data_rectangle(nb):
+    x = torch.rand(nb, 1) - 0.5
+    y = torch.rand(nb, 1) * 2 - 1
+    data = torch.cat((y, x), 1)
+    alpha = math.pi / 8
+    data = data @ torch.tensor(
+        [[math.cos(alpha), math.sin(alpha)], [-math.sin(alpha), math.cos(alpha)]]
+    )
+    return data, "rectangle"
+
+
 def data_zigzag(nb):
     a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
     # zigzag
-    x = 0.4 * ((a-0.5) * 5 * math.pi).cos()
+    x = 0.4 * ((a - 0.5) * 5 * math.pi).cos()
     y = a * 2.5 - 1.25
     data = torch.cat((y, x), 1)
-    data = data @ torch.tensor([[1., -1.], [1., 1.]])
-    return data, 'zigzag'
+    data = data @ torch.tensor([[1.0, -1.0], [1.0, 1.0]])
+    return data, "zigzag"
+
 
 def data_spiral(nb):
     a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
     x = (a * 2.25 * math.pi).cos() * (a * 0.8 + 0.5)
     y = (a * 2.25 * math.pi).sin() * (a * 0.8 + 0.5)
     data = torch.cat((y, x), 1)
-    return data, 'spiral'
+    return data, "spiral"
+
 
 def data_penta(nb):
     a = (torch.randint(5, (nb,)).float() / 5 * 2 * math.pi).view(-1, 1)
@@ -30,19 +49,17 @@ def data_penta(nb):
     y = a.sin()
     data = torch.cat((y, x), 1)
     data = data + data.new(data.size()).normal_(0, 0.05)
-    return data, 'penta'
+    return data, "penta"
+
 
 ######################################################################
 
+
 def train_model(data):
-    model = nn.Sequential(
-        nn.Linear(2, 100),
-        nn.ReLU(),
-        nn.Linear(100, 2)
-    )
+    model = nn.Sequential(nn.Linear(2, 100), nn.ReLU(), nn.Linear(100, 2))
 
     batch_size, nb_epochs = 100, 1000
-    optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
     criterion = nn.MSELoss()
 
     for e in range(nb_epochs):
@@ -55,16 +72,19 @@ def train_model(data):
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
-        if (e+1)%100 == 0: print(e+1, acc_loss)
+        if (e + 1) % 100 == 0:
+            print(e + 1, acc_loss)
 
     return model
 
+
 ######################################################################
 
+
 def save_image(data_name, model, data):
     a = torch.linspace(-1.5, 1.5, 30)
-    x = a.view( 1, -1, 1).expand(a.size(0), a.size(0), 1)
-    y = a.view(-1,  1, 1).expand(a.size(0), a.size(0), 1)
+    x = a.view(1, -1, 1).expand(a.size(0), a.size(0), 1)
+    y = a.view(-1, 1, 1).expand(a.size(0), a.size(0), 1)
     grid = torch.cat((y, x), 2).view(-1, 2)
 
     # Take the origins of the arrows on the part of the grid closer than
@@ -77,30 +97,35 @@ def save_image(data_name, model, data):
     fig = plt.figure()
     ax = fig.add_subplot(1, 1, 1)
 
-    ax.axis('off')
+    ax.axis("off")
     ax.set_xlim(-1.6, 1.6)
     ax.set_ylim(-1.6, 1.6)
     ax.set_aspect(1)
 
     plot_field = ax.quiver(
-        origins[:, 0].numpy(), origins[:, 1].numpy(),
-        field[:, 0].numpy(), field[:, 1].numpy(),
-        units = 'xy', scale = 1,
-        width = 3e-3, headwidth = 25, headlength = 25
+        origins[:, 0].numpy(),
+        origins[:, 1].numpy(),
+        field[:, 0].numpy(),
+        field[:, 1].numpy(),
+        units="xy",
+        scale=1,
+        width=3e-3,
+        headwidth=25,
+        headlength=25,
     )
 
     plot_data = ax.scatter(
-        data[:, 0].numpy(), data[:, 1].numpy(),
-        s = 1, color = 'tab:blue'
+        data[:, 0].numpy(), data[:, 1].numpy(), s=1, color="tab:blue"
     )
 
-    filename = f'denoising_field_{data_name}.pdf'
-    print(f'Saving {filename}')
-    fig.savefig(filename, bbox_inches='tight')
+    filename = f"denoising_field_{data_name}.pdf"
+    print(f"Saving {filename}")
+    fig.savefig(filename, bbox_inches="tight")
+
 
 ######################################################################
 
-for data_source in [ data_zigzag, data_spiral, data_penta ]:
+for data_source in [data_rectangle, data_zigzag, data_spiral, data_penta]:
     data, data_name = data_source(1000)
     data = data - data.mean(0)
     model = train_model(data)