Added the 5 cluster data-set.
[pytorch.git] / denoising-ae-field.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10 model = nn.Sequential(
11     nn.Linear(2, 100),
12     nn.ReLU(),
13     nn.Linear(100, 2)
14 )
15
16 ######################################################################
17
18 def data_zigzag(nb):
19     a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
20     # zigzag
21     x = 0.4 * ((a-0.5) * 5 * math.pi).cos()
22     y = a * 2.5 - 1.25
23     data = torch.cat((y, x), 1)
24     data = data @ torch.tensor([[1., -1.], [1., 1.]])
25     return data
26
27 def data_spiral(nb):
28     a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
29     x = (a * 2.25 * math.pi).cos() * (a * 0.8 + 0.5)
30     y = (a * 2.25 * math.pi).sin() * (a * 0.8 + 0.5)
31     data = torch.cat((y, x), 1)
32     return data
33
34 def data_penta(nb):
35     a = (torch.randint(5, (nb,)).float() / 5 * 2 * math.pi).view(-1, 1)
36     x = a.cos()
37     y = a.sin()
38     data = torch.cat((y, x), 1)
39     data = data + data.new(data.size()).normal_(0, 0.05)
40     return data
41
42 ######################################################################
43
44 data = data_spiral(1000)
45 # data = data_zigzag(1000)
46 # data = data_penta(1000)
47
48 data = data - data.mean(0)
49
50 batch_size, nb_epochs = 100, 1000
51 optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
52 criterion = nn.MSELoss()
53
54 for e in range(nb_epochs):
55     acc_loss = 0
56     for input in data.split(batch_size):
57         noise = input.new(input.size()).normal_(0, 0.1)
58         output = model(input + noise)
59         loss = criterion(output, input)
60         acc_loss += loss.item()
61         optimizer.zero_grad()
62         loss.backward()
63         optimizer.step()
64     if (e+1)%10 == 0: print(e+1, acc_loss)
65
66 ######################################################################
67
68 a = torch.linspace(-1.5, 1.5, 30)
69 x = a.view( 1, -1, 1).expand(a.size(0), a.size(0), 1)
70 y = a.view(-1,  1, 1).expand(a.size(0), a.size(0), 1)
71 grid = torch.cat((y, x), 2).view(-1, 2)
72
73 # Take the origins of the arrows on the part of grid closer than 0.1
74 # from the data points
75 dist = (grid.view(-1, 1, 2) - data.view(1, -1, 2)).pow(2).sum(2).min(1)[0]
76 origins = grid[torch.arange(grid.size(0)).masked_select(dist < 0.1)]
77
78 field = model(origins).detach() - origins
79
80 ######################################################################
81
82 import matplotlib.pyplot as plt
83
84 fig = plt.figure()
85 ax = fig.add_subplot(1, 1, 1)
86
87 ax.axis('off')
88 ax.set_xlim(-1.6, 1.6)
89 ax.set_ylim(-1.6, 1.6)
90 ax.set_aspect(1)
91
92 plot_field = ax.quiver(
93     origins[:, 0].numpy(), origins[:, 1].numpy(),
94     field[:, 0].numpy(), field[:, 1].numpy(),
95     units = 'xy', scale = 1,
96     width = 3e-3, headwidth = 25, headlength = 25
97 )
98
99 plot_data = ax.scatter(
100     data[:, 0].numpy(), data[:, 1].numpy(),
101     s = 1, color = 'tab:blue'
102 )
103
104 fig.savefig('denoising_field.pdf', bbox_inches='tight')
105
106 ######################################################################