X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=miniflow.py;h=04b9a23b7e372b041ffbccb97fe2d2ca9bad0436;hb=77af470cdb98b4a97af6432ca1421c28062b9aae;hp=e4f594588f441775d63885e691bc4a19ce271955;hpb=78b4450fc60a5db62bc8ed50ec54e255a60f24e2;p=pytorch.git diff --git a/miniflow.py b/miniflow.py index e4f5945..04b9a23 100755 --- a/miniflow.py +++ b/miniflow.py @@ -33,18 +33,15 @@ def sample_phi(nb): ###################################################################### -# START_LOG_PROBA def LogProba(x, ldj): log_p = ldj - 0.5 * (x**2 + math.log(2*pi)) return log_p -# END_LOG_PROBA ###################################################################### -# START_MODEL class PiecewiseLinear(nn.Module): def __init__(self, nb, xmin, xmax): - super(PiecewiseLinear, self).__init__() + super().__init__() self.xmin = xmin self.xmax = xmax self.nb = nb @@ -59,7 +56,6 @@ class PiecewiseLinear(nn.Module): a = (u - n).clamp(0, 1) x = (1 - a) * y[n] + a * y[n + 1] return x -# END_MODEL def invert(self, y): ys = self.alpha + self.xi.exp().cumsum(0).view(1, -1) @@ -88,7 +84,6 @@ criterion = nn.MSELoss() for k in range(nb_epochs): acc_loss = 0 -# START_OPTIMIZATION for input in train_input.split(batch_size): input.requires_grad_() output = model(input) @@ -103,7 +98,6 @@ for k in range(nb_epochs): optimizer.zero_grad() loss.backward() optimizer.step() -# END_OPTIMIZATION acc_loss += loss.item() if k%10 == 0: print(k, loss.item()) @@ -202,26 +196,25 @@ fig.savefig(filename, bbox_inches='tight') ###################################################################### -z = torch.empty(200).normal_() -z = z[(z > -3) * (z < 3)] +# z = torch.empty(200).normal_() +# z = z[(z > -3) * (z < 3)] -x = model.invert(z) +# x = model.invert(z) -fig = plt.figure() -ax = fig.add_subplot(1, 1, 1) -ax.set_xlim(-4.5, 4.5) -ax.set_ylim(-0.1, 1.1) -lines = list(([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(x, z)) -lc = mc.LineCollection(lines, color = 'tab:red', linewidth = 0.1) -ax.add_collection(lc) -# ax.axis('off') - -# ax.fill_between(input, 0.52, mu_N * 0.2 + 0.52, color = green_dist) -# ax.fill_between(input, -0.30, mu * 0.2 - 0.30, color = green_dist) +# fig = plt.figure() +# ax = fig.add_subplot(1, 1, 1) +# ax.set_xlim(-4.5, 4.5) +# ax.set_ylim(-0.1, 1.1) +# lines = list(([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(x, z)) +# lc = mc.LineCollection(lines, color = 'tab:red', linewidth = 0.1) +# ax.add_collection(lc) +# # ax.axis('off') -filename = 'miniflow_synth.pdf' -print(f'Saving {filename}') -fig.savefig(filename, bbox_inches='tight') +# # ax.fill_between(input, 0.52, mu_N * 0.2 + 0.52, color = green_dist) +# # ax.fill_between(input, -0.30, mu * 0.2 - 0.30, color = green_dist) -# plt.show() +# filename = 'miniflow_synth.pdf' +# print(f'Saving {filename}') +# fig.savefig(filename, bbox_inches='tight') +# # plt.show()