Update.
[pytorch.git] / denoising-ae-field.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math
9 import matplotlib.pyplot as plt
10
11 import torch
12 from torch import nn
13
14 ######################################################################
15
16
17 def data_rectangle(nb):
18     x = torch.rand(nb, 1) - 0.5
19     y = torch.rand(nb, 1) * 2 - 1
20     data = torch.cat((y, x), 1)
21     alpha = math.pi / 8
22     data = data @ torch.tensor(
23         [[math.cos(alpha), math.sin(alpha)], [-math.sin(alpha), math.cos(alpha)]]
24     )
25     return data, "rectangle"
26
27
28 def data_zigzag(nb):
29     a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
30     # zigzag
31     x = 0.4 * ((a - 0.5) * 5 * math.pi).cos()
32     y = a * 2.5 - 1.25
33     data = torch.cat((y, x), 1)
34     data = data @ torch.tensor([[1.0, -1.0], [1.0, 1.0]])
35     return data, "zigzag"
36
37
38 def data_spiral(nb):
39     a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
40     x = (a * 2.25 * math.pi).cos() * (a * 0.8 + 0.5)
41     y = (a * 2.25 * math.pi).sin() * (a * 0.8 + 0.5)
42     data = torch.cat((y, x), 1)
43     return data, "spiral"
44
45
46 def data_penta(nb):
47     a = (torch.randint(5, (nb,)).float() / 5 * 2 * math.pi).view(-1, 1)
48     x = a.cos()
49     y = a.sin()
50     data = torch.cat((y, x), 1)
51     data = data + data.new(data.size()).normal_(0, 0.05)
52     return data, "penta"
53
54
55 ######################################################################
56
57
58 def train_model(data):
59     model = nn.Sequential(nn.Linear(2, 100), nn.ReLU(), nn.Linear(100, 2))
60
61     batch_size, nb_epochs = 100, 1000
62     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
63     criterion = nn.MSELoss()
64
65     for e in range(nb_epochs):
66         acc_loss = 0
67         for input in data.split(batch_size):
68             noise = input.new(input.size()).normal_(0, 0.1)
69             output = model(input + noise)
70             loss = criterion(output, input)
71             acc_loss += loss.item()
72             optimizer.zero_grad()
73             loss.backward()
74             optimizer.step()
75         if (e + 1) % 100 == 0:
76             print(e + 1, acc_loss)
77
78     return model
79
80
81 ######################################################################
82
83
84 def save_image(data_name, model, data):
85     a = torch.linspace(-1.5, 1.5, 30)
86     x = a.view(1, -1, 1).expand(a.size(0), a.size(0), 1)
87     y = a.view(-1, 1, 1).expand(a.size(0), a.size(0), 1)
88     grid = torch.cat((y, x), 2).view(-1, 2)
89
90     # Take the origins of the arrows on the part of the grid closer than
91     # sqrt(0.1) to the data points
92     dist = (grid.view(-1, 1, 2) - data.view(1, -1, 2)).pow(2).sum(2).min(1)[0]
93     origins = grid[torch.arange(grid.size(0)).masked_select(dist < 0.1)]
94
95     field = model(origins).detach() - origins
96
97     fig = plt.figure()
98     ax = fig.add_subplot(1, 1, 1)
99
100     ax.axis("off")
101     ax.set_xlim(-1.6, 1.6)
102     ax.set_ylim(-1.6, 1.6)
103     ax.set_aspect(1)
104
105     plot_field = ax.quiver(
106         origins[:, 0].numpy(),
107         origins[:, 1].numpy(),
108         field[:, 0].numpy(),
109         field[:, 1].numpy(),
110         units="xy",
111         scale=1,
112         width=3e-3,
113         headwidth=25,
114         headlength=25,
115     )
116
117     plot_data = ax.scatter(
118         data[:, 0].numpy(), data[:, 1].numpy(), s=1, color="tab:blue"
119     )
120
121     filename = f"denoising_field_{data_name}.pdf"
122     print(f"Saving {filename}")
123     fig.savefig(filename, bbox_inches="tight")
124
125
126 ######################################################################
127
128 for data_source in [data_rectangle, data_zigzag, data_spiral, data_penta]:
129     data, data_name = data_source(1000)
130     data = data - data.mean(0)
131     model = train_model(data)
132     save_image(data_name, model, data)