From 78b4450fc60a5db62bc8ed50ec54e255a60f24e2 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Sun, 23 Feb 2020 13:36:35 +0100 Subject: [PATCH] Update. --- miniflow.py | 227 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 227 insertions(+) create mode 100755 miniflow.py diff --git a/miniflow.py b/miniflow.py new file mode 100755 index 0000000..e4f5945 --- /dev/null +++ b/miniflow.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import matplotlib.pyplot as plt +import matplotlib.collections as mc +import numpy as np + +import math +from math import pi + +import torch, torchvision + +from torch import nn, autograd +from torch.nn import functional as F + +###################################################################### + +def phi(x): + p, std = 0.3, 0.2 + mu = (1 - p) * torch.exp(LogProba((x - 0.5) / std, math.log(1 / std))) + \ + p * torch.exp(LogProba((x + 0.5) / std, math.log(1 / std))) + return mu + +def sample_phi(nb): + p, std = 0.3, 0.2 + result = torch.empty(nb).normal_(0, std) + result = result + torch.sign(torch.rand(result.size()) - p) / 2 + return result + +###################################################################### + +# START_LOG_PROBA +def LogProba(x, ldj): + log_p = ldj - 0.5 * (x**2 + math.log(2*pi)) + return log_p +# END_LOG_PROBA + +###################################################################### + +# START_MODEL +class PiecewiseLinear(nn.Module): + def __init__(self, nb, xmin, xmax): + super(PiecewiseLinear, self).__init__() + self.xmin = xmin + self.xmax = xmax + self.nb = nb + self.alpha = nn.Parameter(torch.tensor([xmin], dtype = torch.float)) + mu = math.log((xmax - xmin) / nb) + self.xi = nn.Parameter(torch.empty(nb + 1).normal_(mu, 1e-4)) + + def forward(self, x): + y = self.alpha + self.xi.exp().cumsum(0) + u = self.nb * (x - self.xmin) / (self.xmax - self.xmin) + n = u.long().clamp(0, self.nb - 1) + a = (u - n).clamp(0, 1) + x = (1 - a) * y[n] + a * y[n + 1] + return x +# END_MODEL + + def invert(self, y): + ys = self.alpha + self.xi.exp().cumsum(0).view(1, -1) + yy = y.view(-1, 1) + k = torch.arange(self.nb).view(1, -1) + assert (y >= ys[0, 0]).min() and (y <= ys[0, self.nb]).min() + yk = ys[:, :-1] + ykp1 = ys[:, 1:] + x = self.xmin + (self.xmax - self.xmin)/self.nb * ((yy >= yk) * (yy < ykp1).long() * (k + (yy - yk)/(ykp1 - yk))).sum(1) + return x + +###################################################################### +# Training + +nb_samples = 25000 +nb_epochs = 250 +batch_size = 100 + +model = PiecewiseLinear(nb = 1001, xmin = -4, xmax = 4) + +train_input = sample_phi(nb_samples) + +optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4) +criterion = nn.MSELoss() + +for k in range(nb_epochs): + acc_loss = 0 + +# START_OPTIMIZATION + for input in train_input.split(batch_size): + input.requires_grad_() + output = model(input) + + derivatives, = autograd.grad( + output.sum(), input, + retain_graph = True, create_graph = True + ) + + loss = ( 0.5 * (output**2 + math.log(2*pi)) - derivatives.log() ).mean() + + optimizer.zero_grad() + loss.backward() + optimizer.step() +# END_OPTIMIZATION + + acc_loss += loss.item() + if k%10 == 0: print(k, loss.item()) + +###################################################################### + +input = torch.linspace(-3, 3, 175) + +mu = phi(input) +mu_N = torch.exp(LogProba(input, 0)) + +input.requires_grad_() +output = model(input) + +grad = autograd.grad(output.sum(), input)[0] +mu_hat = LogProba(output, grad.log()).detach().exp() + +###################################################################### +# FIGURES + +input = input.detach().numpy() +output = output.detach().numpy() +mu = mu.numpy() +mu_hat = mu_hat.numpy() + +###################################################################### + +fig = plt.figure() +ax = fig.add_subplot(1, 1, 1) +# ax.set_xlim(-5, 5) +# ax.set_ylim(-0.25, 1.25) +# ax.axis('off') + +ax.plot(input, output, '-', color = 'tab:red') + +filename = 'miniflow_mapping.pdf' +print(f'Saving {filename}') +fig.savefig(filename, bbox_inches='tight') + +# plt.show() + +###################################################################### + +green_dist = '#bfdfbf' + +fig = plt.figure() +ax = fig.add_subplot(1, 1, 1) +# ax.set_xlim(-4.5, 4.5) +# ax.set_ylim(-0.1, 1.1) +lines = list(([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(input, output)) +lc = mc.LineCollection(lines, color = 'tab:red', linewidth = 0.1) +ax.add_collection(lc) +ax.axis('off') + +ax.fill_between(input, 0.52, mu_N * 0.2 + 0.52, color = green_dist) +ax.fill_between(input, -0.30, mu * 0.2 - 0.30, color = green_dist) + +filename = 'miniflow_flow.pdf' +print(f'Saving {filename}') +fig.savefig(filename, bbox_inches='tight') + +# plt.show() + +###################################################################### + +fig = plt.figure() +ax = fig.add_subplot(1, 1, 1) +ax.axis('off') + +ax.fill_between(input, 0, mu, color = green_dist) +# ax.plot(input, mu, '-', color = 'tab:blue') +# ax.step(input, mu_hat, '-', where='mid', color = 'tab:red') +ax.plot(input, mu_hat, '-', color = 'tab:red') + +filename = 'miniflow_dist.pdf' +print(f'Saving {filename}') +fig.savefig(filename, bbox_inches='tight') + +# plt.show() + +###################################################################### + +fig = plt.figure() +ax = fig.add_subplot(1, 1, 1) +ax.axis('off') + +# ax.plot(input, mu, '-', color = 'tab:blue') +ax.fill_between(input, 0, mu, color = green_dist) +# ax.step(input, mu_hat, '-', where='mid', color = 'tab:red') + +filename = 'miniflow_target_dist.pdf' +print(f'Saving {filename}') +fig.savefig(filename, bbox_inches='tight') + +# plt.show() + +###################################################################### + +z = torch.empty(200).normal_() +z = z[(z > -3) * (z < 3)] + +x = model.invert(z) + +fig = plt.figure() +ax = fig.add_subplot(1, 1, 1) +ax.set_xlim(-4.5, 4.5) +ax.set_ylim(-0.1, 1.1) +lines = list(([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(x, z)) +lc = mc.LineCollection(lines, color = 'tab:red', linewidth = 0.1) +ax.add_collection(lc) +# ax.axis('off') + +# ax.fill_between(input, 0.52, mu_N * 0.2 + 0.52, color = green_dist) +# ax.fill_between(input, -0.30, mu * 0.2 - 0.30, color = green_dist) + +filename = 'miniflow_synth.pdf' +print(f'Saving {filename}') +fig.savefig(filename, bbox_inches='tight') + +# plt.show() + -- 2.20.1