Update.
authorFrancois Fleuret <francois@fleuret.org>
Sun, 23 Feb 2020 12:36:35 +0000 (13:36 +0100)
committerFrancois Fleuret <francois@fleuret.org>
Sun, 23 Feb 2020 12:36:35 +0000 (13:36 +0100)
miniflow.py [new file with mode: 0755]

diff --git a/miniflow.py b/miniflow.py
new file mode 100755 (executable)
index 0000000..e4f5945
--- /dev/null
@@ -0,0 +1,227 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import matplotlib.pyplot as plt
+import matplotlib.collections as mc
+import numpy as np
+
+import math
+from math import pi
+
+import torch, torchvision
+
+from torch import nn, autograd
+from torch.nn import functional as F
+
+######################################################################
+
+def phi(x):
+    p, std = 0.3, 0.2
+    mu = (1 - p) * torch.exp(LogProba((x - 0.5) / std, math.log(1 / std))) + \
+              p  * torch.exp(LogProba((x + 0.5) / std, math.log(1 / std)))
+    return mu
+
+def sample_phi(nb):
+    p, std = 0.3, 0.2
+    result = torch.empty(nb).normal_(0, std)
+    result = result + torch.sign(torch.rand(result.size()) - p) / 2
+    return result
+
+######################################################################
+
+# START_LOG_PROBA
+def LogProba(x, ldj):
+    log_p = ldj - 0.5 * (x**2 + math.log(2*pi))
+    return log_p
+# END_LOG_PROBA
+
+######################################################################
+
+# START_MODEL
+class PiecewiseLinear(nn.Module):
+    def __init__(self, nb, xmin, xmax):
+        super(PiecewiseLinear, self).__init__()
+        self.xmin = xmin
+        self.xmax = xmax
+        self.nb = nb
+        self.alpha = nn.Parameter(torch.tensor([xmin], dtype = torch.float))
+        mu = math.log((xmax - xmin) / nb)
+        self.xi = nn.Parameter(torch.empty(nb + 1).normal_(mu, 1e-4))
+
+    def forward(self, x):
+        y = self.alpha + self.xi.exp().cumsum(0)
+        u = self.nb * (x - self.xmin) / (self.xmax - self.xmin)
+        n = u.long().clamp(0, self.nb - 1)
+        a = (u - n).clamp(0, 1)
+        x = (1 - a) * y[n] + a * y[n + 1]
+        return x
+# END_MODEL
+
+    def invert(self, y):
+        ys = self.alpha + self.xi.exp().cumsum(0).view(1, -1)
+        yy = y.view(-1, 1)
+        k = torch.arange(self.nb).view(1, -1)
+        assert (y >= ys[0, 0]).min() and (y <= ys[0, self.nb]).min()
+        yk = ys[:, :-1]
+        ykp1 = ys[:, 1:]
+        x = self.xmin + (self.xmax - self.xmin)/self.nb * ((yy >= yk) * (yy < ykp1).long() * (k + (yy - yk)/(ykp1 - yk))).sum(1)
+        return x
+
+######################################################################
+# Training
+
+nb_samples = 25000
+nb_epochs = 250
+batch_size = 100
+
+model = PiecewiseLinear(nb = 1001, xmin = -4, xmax = 4)
+
+train_input = sample_phi(nb_samples)
+
+optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
+criterion = nn.MSELoss()
+
+for k in range(nb_epochs):
+    acc_loss = 0
+
+# START_OPTIMIZATION
+    for input in train_input.split(batch_size):
+        input.requires_grad_()
+        output = model(input)
+
+        derivatives, = autograd.grad(
+            output.sum(), input,
+            retain_graph = True, create_graph = True
+        )
+
+        loss = ( 0.5 * (output**2 + math.log(2*pi)) - derivatives.log() ).mean()
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+# END_OPTIMIZATION
+
+        acc_loss += loss.item()
+    if k%10 == 0: print(k, loss.item())
+
+######################################################################
+
+input = torch.linspace(-3, 3, 175)
+
+mu = phi(input)
+mu_N = torch.exp(LogProba(input, 0))
+
+input.requires_grad_()
+output = model(input)
+
+grad = autograd.grad(output.sum(), input)[0]
+mu_hat = LogProba(output, grad.log()).detach().exp()
+
+######################################################################
+# FIGURES
+
+input = input.detach().numpy()
+output = output.detach().numpy()
+mu = mu.numpy()
+mu_hat = mu_hat.numpy()
+
+######################################################################
+
+fig = plt.figure()
+ax = fig.add_subplot(1, 1, 1)
+# ax.set_xlim(-5, 5)
+# ax.set_ylim(-0.25, 1.25)
+# ax.axis('off')
+
+ax.plot(input, output, '-', color = 'tab:red')
+
+filename = 'miniflow_mapping.pdf'
+print(f'Saving {filename}')
+fig.savefig(filename, bbox_inches='tight')
+
+# plt.show()
+
+######################################################################
+
+green_dist = '#bfdfbf'
+
+fig = plt.figure()
+ax = fig.add_subplot(1, 1, 1)
+# ax.set_xlim(-4.5, 4.5)
+# ax.set_ylim(-0.1, 1.1)
+lines = list(([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(input, output))
+lc = mc.LineCollection(lines, color = 'tab:red', linewidth = 0.1)
+ax.add_collection(lc)
+ax.axis('off')
+
+ax.fill_between(input,  0.52, mu_N * 0.2 + 0.52, color = green_dist)
+ax.fill_between(input, -0.30, mu   * 0.2 - 0.30, color = green_dist)
+
+filename = 'miniflow_flow.pdf'
+print(f'Saving {filename}')
+fig.savefig(filename, bbox_inches='tight')
+
+# plt.show()
+
+######################################################################
+
+fig = plt.figure()
+ax = fig.add_subplot(1, 1, 1)
+ax.axis('off')
+
+ax.fill_between(input, 0, mu, color = green_dist)
+# ax.plot(input, mu, '-', color = 'tab:blue')
+# ax.step(input, mu_hat, '-', where='mid', color = 'tab:red')
+ax.plot(input, mu_hat, '-', color = 'tab:red')
+
+filename = 'miniflow_dist.pdf'
+print(f'Saving {filename}')
+fig.savefig(filename, bbox_inches='tight')
+
+# plt.show()
+
+######################################################################
+
+fig = plt.figure()
+ax = fig.add_subplot(1, 1, 1)
+ax.axis('off')
+
+# ax.plot(input, mu, '-', color = 'tab:blue')
+ax.fill_between(input, 0, mu, color = green_dist)
+# ax.step(input, mu_hat, '-', where='mid', color = 'tab:red')
+
+filename = 'miniflow_target_dist.pdf'
+print(f'Saving {filename}')
+fig.savefig(filename, bbox_inches='tight')
+
+# plt.show()
+
+######################################################################
+
+z = torch.empty(200).normal_()
+z = z[(z > -3) * (z < 3)]
+
+x = model.invert(z)
+
+fig = plt.figure()
+ax = fig.add_subplot(1, 1, 1)
+ax.set_xlim(-4.5, 4.5)
+ax.set_ylim(-0.1, 1.1)
+lines = list(([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(x, z))
+lc = mc.LineCollection(lines, color = 'tab:red', linewidth = 0.1)
+ax.add_collection(lc)
+# ax.axis('off')
+
+# ax.fill_between(input,  0.52, mu_N * 0.2 + 0.52, color = green_dist)
+# ax.fill_between(input, -0.30, mu   * 0.2 - 0.30, color = green_dist)
+
+filename = 'miniflow_synth.pdf'
+print(f'Saving {filename}')
+fig.savefig(filename, bbox_inches='tight')
+
+# plt.show()
+