Update.
[pytorch.git] / miniflow.py
index e4f5945..ad7b431 100755 (executable)
@@ -19,36 +19,40 @@ from torch.nn import functional as F
 
 ######################################################################
 
+
 def phi(x):
     p, std = 0.3, 0.2
-    mu = (1 - p) * torch.exp(LogProba((x - 0.5) / std, math.log(1 / std))) + \
-              p  * torch.exp(LogProba((x + 0.5) / std, math.log(1 / std)))
+    mu = (1 - p) * torch.exp(
+        LogProba((x - 0.5) / std, math.log(1 / std))
+    ) + p * torch.exp(LogProba((x + 0.5) / std, math.log(1 / std)))
     return mu
 
+
 def sample_phi(nb):
     p, std = 0.3, 0.2
     result = torch.empty(nb).normal_(0, std)
     result = result + torch.sign(torch.rand(result.size()) - p) / 2
     return result
 
+
 ######################################################################
 
-# START_LOG_PROBA
+
 def LogProba(x, ldj):
-    log_p = ldj - 0.5 * (x**2 + math.log(2*pi))
+    log_p = ldj - 0.5 * (x**2 + math.log(2 * pi))
     return log_p
-# END_LOG_PROBA
+
 
 ######################################################################
 
-# START_MODEL
+
 class PiecewiseLinear(nn.Module):
     def __init__(self, nb, xmin, xmax):
-        super(PiecewiseLinear, self).__init__()
+        super().__init__()
         self.xmin = xmin
         self.xmax = xmax
         self.nb = nb
-        self.alpha = nn.Parameter(torch.tensor([xmin], dtype = torch.float))
+        self.alpha = nn.Parameter(torch.tensor([xmin], dtype=torch.float))
         mu = math.log((xmax - xmin) / nb)
         self.xi = nn.Parameter(torch.empty(nb + 1).normal_(mu, 1e-4))
 
@@ -59,7 +63,6 @@ class PiecewiseLinear(nn.Module):
         a = (u - n).clamp(0, 1)
         x = (1 - a) * y[n] + a * y[n + 1]
         return x
-# END_MODEL
 
     def invert(self, y):
         ys = self.alpha + self.xi.exp().cumsum(0).view(1, -1)
@@ -68,9 +71,12 @@ class PiecewiseLinear(nn.Module):
         assert (y >= ys[0, 0]).min() and (y <= ys[0, self.nb]).min()
         yk = ys[:, :-1]
         ykp1 = ys[:, 1:]
-        x = self.xmin + (self.xmax - self.xmin)/self.nb * ((yy >= yk) * (yy < ykp1).long() * (k + (yy - yk)/(ykp1 - yk))).sum(1)
+        x = self.xmin + (self.xmax - self.xmin) / self.nb * (
+            (yy >= yk) * (yy < ykp1).long() * (k + (yy - yk) / (ykp1 - yk))
+        ).sum(1)
         return x
 
+
 ######################################################################
 # Training
 
@@ -78,35 +84,33 @@ nb_samples = 25000
 nb_epochs = 250
 batch_size = 100
 
-model = PiecewiseLinear(nb = 1001, xmin = -4, xmax = 4)
+model = PiecewiseLinear(nb=1001, xmin=-4, xmax=4)
 
 train_input = sample_phi(nb_samples)
 
-optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
+optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
 criterion = nn.MSELoss()
 
 for k in range(nb_epochs):
     acc_loss = 0
 
-# START_OPTIMIZATION
     for input in train_input.split(batch_size):
         input.requires_grad_()
         output = model(input)
 
-        derivatives, = autograd.grad(
-            output.sum(), input,
-            retain_graph = True, create_graph = True
+        (derivatives,) = autograd.grad(
+            output.sum(), input, retain_graph=True, create_graph=True
         )
 
-        loss = ( 0.5 * (output**2 + math.log(2*pi)) - derivatives.log() ).mean()
+        loss = (0.5 * (output**2 + math.log(2 * pi)) - derivatives.log()).mean()
 
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
-# END_OPTIMIZATION
 
         acc_loss += loss.item()
-    if k%10 == 0: print(k, loss.item())
+    if k % 10 == 0:
+        print(k, loss.item())
 
 ######################################################################
 
@@ -137,33 +141,35 @@ ax = fig.add_subplot(1, 1, 1)
 # ax.set_ylim(-0.25, 1.25)
 # ax.axis('off')
 
-ax.plot(input, output, '-', color = 'tab:red')
+ax.plot(input, output, "-", color="tab:red")
 
-filename = 'miniflow_mapping.pdf'
-print(f'Saving {filename}')
-fig.savefig(filename, bbox_inches='tight')
+filename = "miniflow_mapping.pdf"
+print(f"Saving {filename}")
+fig.savefig(filename, bbox_inches="tight")
 
 # plt.show()
 
 ######################################################################
 
-green_dist = '#bfdfbf'
+green_dist = "#bfdfbf"
 
 fig = plt.figure()
 ax = fig.add_subplot(1, 1, 1)
 # ax.set_xlim(-4.5, 4.5)
 # ax.set_ylim(-0.1, 1.1)
-lines = list(([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(input, output))
-lc = mc.LineCollection(lines, color = 'tab:red', linewidth = 0.1)
+lines = list(
+    ([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(input, output)
+)
+lc = mc.LineCollection(lines, color="tab:red", linewidth=0.1)
 ax.add_collection(lc)
-ax.axis('off')
+ax.axis("off")
 
-ax.fill_between(input,  0.52, mu_N * 0.2 + 0.52, color = green_dist)
-ax.fill_between(input, -0.30, mu   * 0.2 - 0.30, color = green_dist)
+ax.fill_between(input, 0.52, mu_N * 0.2 + 0.52, color=green_dist)
+ax.fill_between(input, -0.30, mu * 0.2 - 0.30, color=green_dist)
 
-filename = 'miniflow_flow.pdf'
-print(f'Saving {filename}')
-fig.savefig(filename, bbox_inches='tight')
+filename = "miniflow_flow.pdf"
+print(f"Saving {filename}")
+fig.savefig(filename, bbox_inches="tight")
 
 # plt.show()
 
@@ -171,16 +177,16 @@ fig.savefig(filename, bbox_inches='tight')
 
 fig = plt.figure()
 ax = fig.add_subplot(1, 1, 1)
-ax.axis('off')
+ax.axis("off")
 
-ax.fill_between(input, 0, mu, color = green_dist)
+ax.fill_between(input, 0, mu, color=green_dist)
 # ax.plot(input, mu, '-', color = 'tab:blue')
 # ax.step(input, mu_hat, '-', where='mid', color = 'tab:red')
-ax.plot(input, mu_hat, '-', color = 'tab:red')
+ax.plot(input, mu_hat, "-", color="tab:red")
 
-filename = 'miniflow_dist.pdf'
-print(f'Saving {filename}')
-fig.savefig(filename, bbox_inches='tight')
+filename = "miniflow_dist.pdf"
+print(f"Saving {filename}")
+fig.savefig(filename, bbox_inches="tight")
 
 # plt.show()
 
@@ -188,40 +194,39 @@ fig.savefig(filename, bbox_inches='tight')
 
 fig = plt.figure()
 ax = fig.add_subplot(1, 1, 1)
-ax.axis('off')
+ax.axis("off")
 
 # ax.plot(input, mu, '-', color = 'tab:blue')
-ax.fill_between(input, 0, mu, color = green_dist)
+ax.fill_between(input, 0, mu, color=green_dist)
 # ax.step(input, mu_hat, '-', where='mid', color = 'tab:red')
 
-filename = 'miniflow_target_dist.pdf'
-print(f'Saving {filename}')
-fig.savefig(filename, bbox_inches='tight')
+filename = "miniflow_target_dist.pdf"
+print(f"Saving {filename}")
+fig.savefig(filename, bbox_inches="tight")
 
 # plt.show()
 
 ######################################################################
 
-z = torch.empty(200).normal_()
-z = z[(z > -3) * (z < 3)]
+z = torch.empty(200).normal_()
+z = z[(z > -3) * (z < 3)]
 
-x = model.invert(z)
+x = model.invert(z)
 
-fig = plt.figure()
-ax = fig.add_subplot(1, 1, 1)
-ax.set_xlim(-4.5, 4.5)
-ax.set_ylim(-0.1, 1.1)
-lines = list(([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(x, z))
-lc = mc.LineCollection(lines, color = 'tab:red', linewidth = 0.1)
-ax.add_collection(lc)
-# ax.axis('off')
-
-# ax.fill_between(input,  0.52, mu_N * 0.2 + 0.52, color = green_dist)
-# ax.fill_between(input, -0.30, mu   * 0.2 - 0.30, color = green_dist)
+# fig = plt.figure()
+# ax = fig.add_subplot(1, 1, 1)
+# ax.set_xlim(-4.5, 4.5)
+# ax.set_ylim(-0.1, 1.1)
+# lines = list(([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(x, z))
+# lc = mc.LineCollection(lines, color = 'tab:red', linewidth = 0.1)
+# ax.add_collection(lc)
+# # ax.axis('off')
 
-filename = 'miniflow_synth.pdf'
-print(f'Saving {filename}')
-fig.savefig(filename, bbox_inches='tight')
+# # ax.fill_between(input,  0.52, mu_N * 0.2 + 0.52, color = green_dist)
+# # ax.fill_between(input, -0.30, mu   * 0.2 - 0.30, color = green_dist)
 
-# plt.show()
+# filename = 'miniflow_synth.pdf'
+# print(f'Saving {filename}')
+# fig.savefig(filename, bbox_inches='tight')
 
+# # plt.show()