Update.
[pytorch.git] / miniflow.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import matplotlib.pyplot as plt
9 import matplotlib.collections as mc
10 import numpy as np
11
12 import math
13 from math import pi
14
15 import torch, torchvision
16
17 from torch import nn, autograd
18 from torch.nn import functional as F
19
20 ######################################################################
21
22 def phi(x):
23     p, std = 0.3, 0.2
24     mu = (1 - p) * torch.exp(LogProba((x - 0.5) / std, math.log(1 / std))) + \
25               p  * torch.exp(LogProba((x + 0.5) / std, math.log(1 / std)))
26     return mu
27
28 def sample_phi(nb):
29     p, std = 0.3, 0.2
30     result = torch.empty(nb).normal_(0, std)
31     result = result + torch.sign(torch.rand(result.size()) - p) / 2
32     return result
33
34 ######################################################################
35
36 def LogProba(x, ldj):
37     log_p = ldj - 0.5 * (x**2 + math.log(2*pi))
38     return log_p
39
40 ######################################################################
41
42 class PiecewiseLinear(nn.Module):
43     def __init__(self, nb, xmin, xmax):
44         super().__init__()
45         self.xmin = xmin
46         self.xmax = xmax
47         self.nb = nb
48         self.alpha = nn.Parameter(torch.tensor([xmin], dtype = torch.float))
49         mu = math.log((xmax - xmin) / nb)
50         self.xi = nn.Parameter(torch.empty(nb + 1).normal_(mu, 1e-4))
51
52     def forward(self, x):
53         y = self.alpha + self.xi.exp().cumsum(0)
54         u = self.nb * (x - self.xmin) / (self.xmax - self.xmin)
55         n = u.long().clamp(0, self.nb - 1)
56         a = (u - n).clamp(0, 1)
57         x = (1 - a) * y[n] + a * y[n + 1]
58         return x
59
60     def invert(self, y):
61         ys = self.alpha + self.xi.exp().cumsum(0).view(1, -1)
62         yy = y.view(-1, 1)
63         k = torch.arange(self.nb).view(1, -1)
64         assert (y >= ys[0, 0]).min() and (y <= ys[0, self.nb]).min()
65         yk = ys[:, :-1]
66         ykp1 = ys[:, 1:]
67         x = self.xmin + (self.xmax - self.xmin)/self.nb * ((yy >= yk) * (yy < ykp1).long() * (k + (yy - yk)/(ykp1 - yk))).sum(1)
68         return x
69
70 ######################################################################
71 # Training
72
73 nb_samples = 25000
74 nb_epochs = 250
75 batch_size = 100
76
77 model = PiecewiseLinear(nb = 1001, xmin = -4, xmax = 4)
78
79 train_input = sample_phi(nb_samples)
80
81 optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
82 criterion = nn.MSELoss()
83
84 for k in range(nb_epochs):
85     acc_loss = 0
86
87     for input in train_input.split(batch_size):
88         input.requires_grad_()
89         output = model(input)
90
91         derivatives, = autograd.grad(
92             output.sum(), input,
93             retain_graph = True, create_graph = True
94         )
95
96         loss = ( 0.5 * (output**2 + math.log(2*pi)) - derivatives.log() ).mean()
97
98         optimizer.zero_grad()
99         loss.backward()
100         optimizer.step()
101
102         acc_loss += loss.item()
103     if k%10 == 0: print(k, loss.item())
104
105 ######################################################################
106
107 input = torch.linspace(-3, 3, 175)
108
109 mu = phi(input)
110 mu_N = torch.exp(LogProba(input, 0))
111
112 input.requires_grad_()
113 output = model(input)
114
115 grad = autograd.grad(output.sum(), input)[0]
116 mu_hat = LogProba(output, grad.log()).detach().exp()
117
118 ######################################################################
119 # FIGURES
120
121 input = input.detach().numpy()
122 output = output.detach().numpy()
123 mu = mu.numpy()
124 mu_hat = mu_hat.numpy()
125
126 ######################################################################
127
128 fig = plt.figure()
129 ax = fig.add_subplot(1, 1, 1)
130 # ax.set_xlim(-5, 5)
131 # ax.set_ylim(-0.25, 1.25)
132 # ax.axis('off')
133
134 ax.plot(input, output, '-', color = 'tab:red')
135
136 filename = 'miniflow_mapping.pdf'
137 print(f'Saving {filename}')
138 fig.savefig(filename, bbox_inches='tight')
139
140 # plt.show()
141
142 ######################################################################
143
144 green_dist = '#bfdfbf'
145
146 fig = plt.figure()
147 ax = fig.add_subplot(1, 1, 1)
148 # ax.set_xlim(-4.5, 4.5)
149 # ax.set_ylim(-0.1, 1.1)
150 lines = list(([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(input, output))
151 lc = mc.LineCollection(lines, color = 'tab:red', linewidth = 0.1)
152 ax.add_collection(lc)
153 ax.axis('off')
154
155 ax.fill_between(input,  0.52, mu_N * 0.2 + 0.52, color = green_dist)
156 ax.fill_between(input, -0.30, mu   * 0.2 - 0.30, color = green_dist)
157
158 filename = 'miniflow_flow.pdf'
159 print(f'Saving {filename}')
160 fig.savefig(filename, bbox_inches='tight')
161
162 # plt.show()
163
164 ######################################################################
165
166 fig = plt.figure()
167 ax = fig.add_subplot(1, 1, 1)
168 ax.axis('off')
169
170 ax.fill_between(input, 0, mu, color = green_dist)
171 # ax.plot(input, mu, '-', color = 'tab:blue')
172 # ax.step(input, mu_hat, '-', where='mid', color = 'tab:red')
173 ax.plot(input, mu_hat, '-', color = 'tab:red')
174
175 filename = 'miniflow_dist.pdf'
176 print(f'Saving {filename}')
177 fig.savefig(filename, bbox_inches='tight')
178
179 # plt.show()
180
181 ######################################################################
182
183 fig = plt.figure()
184 ax = fig.add_subplot(1, 1, 1)
185 ax.axis('off')
186
187 # ax.plot(input, mu, '-', color = 'tab:blue')
188 ax.fill_between(input, 0, mu, color = green_dist)
189 # ax.step(input, mu_hat, '-', where='mid', color = 'tab:red')
190
191 filename = 'miniflow_target_dist.pdf'
192 print(f'Saving {filename}')
193 fig.savefig(filename, bbox_inches='tight')
194
195 # plt.show()
196
197 ######################################################################
198
199 # z = torch.empty(200).normal_()
200 # z = z[(z > -3) * (z < 3)]
201
202 # x = model.invert(z)
203
204 # fig = plt.figure()
205 # ax = fig.add_subplot(1, 1, 1)
206 # ax.set_xlim(-4.5, 4.5)
207 # ax.set_ylim(-0.1, 1.1)
208 # lines = list(([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(x, z))
209 # lc = mc.LineCollection(lines, color = 'tab:red', linewidth = 0.1)
210 # ax.add_collection(lc)
211 # # ax.axis('off')
212
213 # # ax.fill_between(input,  0.52, mu_N * 0.2 + 0.52, color = green_dist)
214 # # ax.fill_between(input, -0.30, mu   * 0.2 - 0.30, color = green_dist)
215
216 # filename = 'miniflow_synth.pdf'
217 # print(f'Saving {filename}')
218 # fig.savefig(filename, bbox_inches='tight')
219
220 # # plt.show()