X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=miniflow.py;h=04b9a23b7e372b041ffbccb97fe2d2ca9bad0436;hp=eb2d4c724163acd3e67e1fef1df86d4776913174;hb=HEAD;hpb=114939c1db199607c9c1f6df0d23e5c23b2915f2 diff --git a/miniflow.py b/miniflow.py index eb2d4c7..ad7b431 100755 --- a/miniflow.py +++ b/miniflow.py @@ -19,36 +19,40 @@ from torch.nn import functional as F ###################################################################### + def phi(x): p, std = 0.3, 0.2 - mu = (1 - p) * torch.exp(LogProba((x - 0.5) / std, math.log(1 / std))) + \ - p * torch.exp(LogProba((x + 0.5) / std, math.log(1 / std))) + mu = (1 - p) * torch.exp( + LogProba((x - 0.5) / std, math.log(1 / std)) + ) + p * torch.exp(LogProba((x + 0.5) / std, math.log(1 / std))) return mu + def sample_phi(nb): p, std = 0.3, 0.2 result = torch.empty(nb).normal_(0, std) result = result + torch.sign(torch.rand(result.size()) - p) / 2 return result + ###################################################################### -# START_LOG_PROBA + def LogProba(x, ldj): - log_p = ldj - 0.5 * (x**2 + math.log(2*pi)) + log_p = ldj - 0.5 * (x**2 + math.log(2 * pi)) return log_p -# END_LOG_PROBA + ###################################################################### -# START_MODEL + class PiecewiseLinear(nn.Module): def __init__(self, nb, xmin, xmax): super().__init__() self.xmin = xmin self.xmax = xmax self.nb = nb - self.alpha = nn.Parameter(torch.tensor([xmin], dtype = torch.float)) + self.alpha = nn.Parameter(torch.tensor([xmin], dtype=torch.float)) mu = math.log((xmax - xmin) / nb) self.xi = nn.Parameter(torch.empty(nb + 1).normal_(mu, 1e-4)) @@ -59,7 +63,6 @@ class PiecewiseLinear(nn.Module): a = (u - n).clamp(0, 1) x = (1 - a) * y[n] + a * y[n + 1] return x -# END_MODEL def invert(self, y): ys = self.alpha + self.xi.exp().cumsum(0).view(1, -1) @@ -68,9 +71,12 @@ class PiecewiseLinear(nn.Module): assert (y >= ys[0, 0]).min() and (y <= ys[0, self.nb]).min() yk = ys[:, :-1] ykp1 = ys[:, 1:] - x = self.xmin + (self.xmax - self.xmin)/self.nb * ((yy >= yk) * (yy < ykp1).long() * (k + (yy - yk)/(ykp1 - yk))).sum(1) + x = self.xmin + (self.xmax - self.xmin) / self.nb * ( + (yy >= yk) * (yy < ykp1).long() * (k + (yy - yk) / (ykp1 - yk)) + ).sum(1) return x + ###################################################################### # Training @@ -78,35 +84,33 @@ nb_samples = 25000 nb_epochs = 250 batch_size = 100 -model = PiecewiseLinear(nb = 1001, xmin = -4, xmax = 4) +model = PiecewiseLinear(nb=1001, xmin=-4, xmax=4) train_input = sample_phi(nb_samples) -optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4) +optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) criterion = nn.MSELoss() for k in range(nb_epochs): acc_loss = 0 -# START_OPTIMIZATION for input in train_input.split(batch_size): input.requires_grad_() output = model(input) - derivatives, = autograd.grad( - output.sum(), input, - retain_graph = True, create_graph = True + (derivatives,) = autograd.grad( + output.sum(), input, retain_graph=True, create_graph=True ) - loss = ( 0.5 * (output**2 + math.log(2*pi)) - derivatives.log() ).mean() + loss = (0.5 * (output**2 + math.log(2 * pi)) - derivatives.log()).mean() optimizer.zero_grad() loss.backward() optimizer.step() -# END_OPTIMIZATION acc_loss += loss.item() - if k%10 == 0: print(k, loss.item()) + if k % 10 == 0: + print(k, loss.item()) ###################################################################### @@ -137,33 +141,35 @@ ax = fig.add_subplot(1, 1, 1) # ax.set_ylim(-0.25, 1.25) # ax.axis('off') -ax.plot(input, output, '-', color = 'tab:red') +ax.plot(input, output, "-", color="tab:red") -filename = 'miniflow_mapping.pdf' -print(f'Saving {filename}') -fig.savefig(filename, bbox_inches='tight') +filename = "miniflow_mapping.pdf" +print(f"Saving {filename}") +fig.savefig(filename, bbox_inches="tight") # plt.show() ###################################################################### -green_dist = '#bfdfbf' +green_dist = "#bfdfbf" fig = plt.figure() ax = fig.add_subplot(1, 1, 1) # ax.set_xlim(-4.5, 4.5) # ax.set_ylim(-0.1, 1.1) -lines = list(([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(input, output)) -lc = mc.LineCollection(lines, color = 'tab:red', linewidth = 0.1) +lines = list( + ([(x_in.item(), 0), (x_out.item(), 0.5)]) for (x_in, x_out) in zip(input, output) +) +lc = mc.LineCollection(lines, color="tab:red", linewidth=0.1) ax.add_collection(lc) -ax.axis('off') +ax.axis("off") -ax.fill_between(input, 0.52, mu_N * 0.2 + 0.52, color = green_dist) -ax.fill_between(input, -0.30, mu * 0.2 - 0.30, color = green_dist) +ax.fill_between(input, 0.52, mu_N * 0.2 + 0.52, color=green_dist) +ax.fill_between(input, -0.30, mu * 0.2 - 0.30, color=green_dist) -filename = 'miniflow_flow.pdf' -print(f'Saving {filename}') -fig.savefig(filename, bbox_inches='tight') +filename = "miniflow_flow.pdf" +print(f"Saving {filename}") +fig.savefig(filename, bbox_inches="tight") # plt.show() @@ -171,16 +177,16 @@ fig.savefig(filename, bbox_inches='tight') fig = plt.figure() ax = fig.add_subplot(1, 1, 1) -ax.axis('off') +ax.axis("off") -ax.fill_between(input, 0, mu, color = green_dist) +ax.fill_between(input, 0, mu, color=green_dist) # ax.plot(input, mu, '-', color = 'tab:blue') # ax.step(input, mu_hat, '-', where='mid', color = 'tab:red') -ax.plot(input, mu_hat, '-', color = 'tab:red') +ax.plot(input, mu_hat, "-", color="tab:red") -filename = 'miniflow_dist.pdf' -print(f'Saving {filename}') -fig.savefig(filename, bbox_inches='tight') +filename = "miniflow_dist.pdf" +print(f"Saving {filename}") +fig.savefig(filename, bbox_inches="tight") # plt.show() @@ -188,15 +194,15 @@ fig.savefig(filename, bbox_inches='tight') fig = plt.figure() ax = fig.add_subplot(1, 1, 1) -ax.axis('off') +ax.axis("off") # ax.plot(input, mu, '-', color = 'tab:blue') -ax.fill_between(input, 0, mu, color = green_dist) +ax.fill_between(input, 0, mu, color=green_dist) # ax.step(input, mu_hat, '-', where='mid', color = 'tab:red') -filename = 'miniflow_target_dist.pdf' -print(f'Saving {filename}') -fig.savefig(filename, bbox_inches='tight') +filename = "miniflow_target_dist.pdf" +print(f"Saving {filename}") +fig.savefig(filename, bbox_inches="tight") # plt.show()