Cleaning up. Now saving all the figures.
[pytorch.git] / denoising-ae-field.py
1 #!/usr/bin/env python
2
3 import math
4 import matplotlib.pyplot as plt
5
6 import torch
7 from torch import nn
8
9 ######################################################################
10
11 def data_zigzag(nb):
12     a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
13     # zigzag
14     x = 0.4 * ((a-0.5) * 5 * math.pi).cos()
15     y = a * 2.5 - 1.25
16     data = torch.cat((y, x), 1)
17     data = data @ torch.tensor([[1., -1.], [1., 1.]])
18     return data, 'zigzag'
19
20 def data_spiral(nb):
21     a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
22     x = (a * 2.25 * math.pi).cos() * (a * 0.8 + 0.5)
23     y = (a * 2.25 * math.pi).sin() * (a * 0.8 + 0.5)
24     data = torch.cat((y, x), 1)
25     return data, 'spiral'
26
27 def data_penta(nb):
28     a = (torch.randint(5, (nb,)).float() / 5 * 2 * math.pi).view(-1, 1)
29     x = a.cos()
30     y = a.sin()
31     data = torch.cat((y, x), 1)
32     data = data + data.new(data.size()).normal_(0, 0.05)
33     return data, 'penta'
34
35 ######################################################################
36
37 def train_model(data):
38     model = nn.Sequential(
39         nn.Linear(2, 100),
40         nn.ReLU(),
41         nn.Linear(100, 2)
42     )
43
44     batch_size, nb_epochs = 100, 1000
45     optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
46     criterion = nn.MSELoss()
47
48     for e in range(nb_epochs):
49         acc_loss = 0
50         for input in data.split(batch_size):
51             noise = input.new(input.size()).normal_(0, 0.1)
52             output = model(input + noise)
53             loss = criterion(output, input)
54             acc_loss += loss.item()
55             optimizer.zero_grad()
56             loss.backward()
57             optimizer.step()
58         if (e+1)%100 == 0: print(e+1, acc_loss)
59
60     return model
61
62 ######################################################################
63
64 def save_image(data, data_name, model):
65     a = torch.linspace(-1.5, 1.5, 30)
66     x = a.view( 1, -1, 1).expand(a.size(0), a.size(0), 1)
67     y = a.view(-1,  1, 1).expand(a.size(0), a.size(0), 1)
68     grid = torch.cat((y, x), 2).view(-1, 2)
69
70     # Take the origins of the arrows on the part of the grid closer than
71     # sqrt(0.1) to the data points
72     dist = (grid.view(-1, 1, 2) - data.view(1, -1, 2)).pow(2).sum(2).min(1)[0]
73     origins = grid[torch.arange(grid.size(0)).masked_select(dist < 0.1)]
74
75     field = model(origins).detach() - origins
76
77     fig = plt.figure()
78     ax = fig.add_subplot(1, 1, 1)
79
80     ax.axis('off')
81     ax.set_xlim(-1.6, 1.6)
82     ax.set_ylim(-1.6, 1.6)
83     ax.set_aspect(1)
84
85     plot_field = ax.quiver(
86         origins[:, 0].numpy(), origins[:, 1].numpy(),
87         field[:, 0].numpy(), field[:, 1].numpy(),
88         units = 'xy', scale = 1,
89         width = 3e-3, headwidth = 25, headlength = 25
90     )
91
92     plot_data = ax.scatter(
93         data[:, 0].numpy(), data[:, 1].numpy(),
94         s = 1, color = 'tab:blue'
95     )
96
97     filename = f'denoising_field_{data_name}.pdf'
98     print(f'Saving {filename}')
99     fig.savefig(filename, bbox_inches='tight')
100
101 ######################################################################
102
103 for data_source in [ data_zigzag, data_spiral, data_penta ]:
104     data, data_name = data_source(1000)
105     data = data - data.mean(0)
106     model = train_model(data)
107     save_image(data, data_name, model)