Initial commit.
[pytorch.git] / miniflow.py
index e4f5945..04b9a23 100755 (executable)
@@ -33,18 +33,15 @@ def sample_phi(nb):
 
 ######################################################################
 
-# START_LOG_PROBA
 def LogProba(x, ldj):
     log_p = ldj - 0.5 * (x**2 + math.log(2*pi))
     return log_p
-# END_LOG_PROBA
 
 ######################################################################
 
-# START_MODEL
 class PiecewiseLinear(nn.Module):
     def __init__(self, nb, xmin, xmax):
-        super(PiecewiseLinear, self).__init__()
+        super().__init__()
         self.xmin = xmin
         self.xmax = xmax
         self.nb = nb
@@ -59,7 +56,6 @@ class PiecewiseLinear(nn.Module):
         a = (u - n).clamp(0, 1)
         x = (1 - a) * y[n] + a * y[n + 1]
         return x
-# END_MODEL
 
     def invert(self, y):
         ys = self.alpha + self.xi.exp().cumsum(0).view(1, -1)
@@ -88,7 +84,6 @@ criterion = nn.MSELoss()
 for k in range(nb_epochs):
     acc_loss = 0
 
-# START_OPTIMIZATION
     for input in train_input.split(batch_size):
         input.requires_grad_()
         output = model(input)
@@ -103,7 +98,6 @@ for k in range(nb_epochs):
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
-# END_OPTIMIZATION
 
         acc_loss += loss.item()
     if k%10 == 0: print(k, loss.item())
@@ -202,26 +196,25 @@ fig.savefig(filename, bbox_inches='tight')
 
 ######################################################################
 
-z = torch.empty(200).normal_()
-z = z[(z > -3) * (z < 3)]
+z = torch.empty(200).normal_()
+z = z[(z > -3) * (z < 3)]
 
-x = model.invert(z)
+x = model.invert(z)
 
-fig = plt.figure()
-ax = fig.add_subplot(1, 1, 1)
-ax.set_xlim(-4.5, 4.5)
-ax.set_ylim(-0.1, 1.1)
-lines = list(([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(x, z))
-lc = mc.LineCollection(lines, color = 'tab:red', linewidth = 0.1)
-ax.add_collection(lc)
-# ax.axis('off')
-
-# ax.fill_between(input,  0.52, mu_N * 0.2 + 0.52, color = green_dist)
-# ax.fill_between(input, -0.30, mu   * 0.2 - 0.30, color = green_dist)
+# fig = plt.figure()
+# ax = fig.add_subplot(1, 1, 1)
+# ax.set_xlim(-4.5, 4.5)
+# ax.set_ylim(-0.1, 1.1)
+# lines = list(([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(x, z))
+# lc = mc.LineCollection(lines, color = 'tab:red', linewidth = 0.1)
+# ax.add_collection(lc)
+# # ax.axis('off')
 
-filename = 'miniflow_synth.pdf'
-print(f'Saving {filename}')
-fig.savefig(filename, bbox_inches='tight')
+# # ax.fill_between(input,  0.52, mu_N * 0.2 + 0.52, color = green_dist)
+# # ax.fill_between(input, -0.30, mu   * 0.2 - 0.30, color = green_dist)
 
-# plt.show()
+# filename = 'miniflow_synth.pdf'
+# print(f'Saving {filename}')
+# fig.savefig(filename, bbox_inches='tight')
 
+# # plt.show()