X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=denoising-ae-field.py;h=f96c23a1243a003b7c0f8cb038e0f63062ca4b9c;hp=67d2415209b8bacfa1b7eb3e33b365b0bd848984;hb=HEAD;hpb=93ce0d3dc04fb72d098366020a0fb4b3451dee0f diff --git a/denoising-ae-field.py b/denoising-ae-field.py index 67d2415..3ef0c80 100755 --- a/denoising-ae-field.py +++ b/denoising-ae-field.py @@ -1,92 +1,132 @@ #!/usr/bin/env python -import math +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ -import torch, torchvision +# Written by Francois Fleuret +import math +import matplotlib.pyplot as plt + +import torch from torch import nn -from torch.nn import functional as F -model = nn.Sequential( - nn.Linear(2, 100), - nn.ReLU(), - nn.Linear(100, 2) -) +###################################################################### + + +def data_rectangle(nb): + x = torch.rand(nb, 1) - 0.5 + y = torch.rand(nb, 1) * 2 - 1 + data = torch.cat((y, x), 1) + alpha = math.pi / 8 + data = data @ torch.tensor( + [[math.cos(alpha), math.sin(alpha)], [-math.sin(alpha), math.cos(alpha)]] + ) + return data, "rectangle" -############################################################ def data_zigzag(nb): a = torch.empty(nb).uniform_(0, 1).view(-1, 1) # zigzag - x = 0.4 * ((a-0.5) * 5 * math.pi).cos() + x = 0.4 * ((a - 0.5) * 5 * math.pi).cos() y = a * 2.5 - 1.25 data = torch.cat((y, x), 1) - data = data @ torch.tensor([[1., -1.], [1., 1.]]) - return data + data = data @ torch.tensor([[1.0, -1.0], [1.0, 1.0]]) + return data, "zigzag" + def data_spiral(nb): a = torch.empty(nb).uniform_(0, 1).view(-1, 1) x = (a * 2.25 * math.pi).cos() * (a * 0.8 + 0.5) y = (a * 2.25 * math.pi).sin() * (a * 0.8 + 0.5) data = torch.cat((y, x), 1) - return data + return data, "spiral" + + +def data_penta(nb): + a = (torch.randint(5, (nb,)).float() / 5 * 2 * math.pi).view(-1, 1) + x = a.cos() + y = a.sin() + data = torch.cat((y, x), 1) + data = data + data.new(data.size()).normal_(0, 0.05) + return data, "penta" + ###################################################################### -# data = data_spiral(1000) -data = data_zigzag(1000) -data = data - data.mean(0) +def train_model(data): + model = nn.Sequential(nn.Linear(2, 100), nn.ReLU(), nn.Linear(100, 2)) + + batch_size, nb_epochs = 100, 1000 + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + criterion = nn.MSELoss() -batch_size, nb_epochs = 100, 1000 -optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3) -criterion = nn.MSELoss() + for e in range(nb_epochs): + acc_loss = 0 + for input in data.split(batch_size): + noise = input.new(input.size()).normal_(0, 0.1) + output = model(input + noise) + loss = criterion(output, input) + acc_loss += loss.item() + optimizer.zero_grad() + loss.backward() + optimizer.step() + if (e + 1) % 100 == 0: + print(e + 1, acc_loss) + + return model -for e in range(nb_epochs): - acc_loss = 0 - for input in data.split(batch_size): - noise = input.new(input.size()).normal_(0, 0.1) - output = model(input + noise) - loss = criterion(output, input) - acc_loss += loss.item() - optimizer.zero_grad() - loss.backward() - optimizer.step() - if (e+1)%10 == 0: print(e+1, acc_loss) ###################################################################### -a = torch.linspace(-1.5, 1.5, 30) -x = a.view( 1, -1, 1).expand(a.size(0), a.size(0), 1) -y = a.view(-1, 1, 1).expand(a.size(0), a.size(0), 1) -grid = torch.cat((y, x), 2).view(-1, 2) -# Take the origins of the arrows on the part of grid closer than 0.1 -# from the data points -dist = (grid.view(-1, 1, 2) - data.view(1, -1, 2)).pow(2).sum(2).min(1)[0] -origins = grid[torch.arange(grid.size(0)).masked_select(dist < 0.1)] +def save_image(data_name, model, data): + a = torch.linspace(-1.5, 1.5, 30) + x = a.view(1, -1, 1).expand(a.size(0), a.size(0), 1) + y = a.view(-1, 1, 1).expand(a.size(0), a.size(0), 1) + grid = torch.cat((y, x), 2).view(-1, 2) -field = model(origins).detach() - origins + # Take the origins of the arrows on the part of the grid closer than + # sqrt(0.1) to the data points + dist = (grid.view(-1, 1, 2) - data.view(1, -1, 2)).pow(2).sum(2).min(1)[0] + origins = grid[torch.arange(grid.size(0)).masked_select(dist < 0.1)] -###################################################################### + field = model(origins).detach() - origins -import matplotlib.pyplot as plt + fig = plt.figure() + ax = fig.add_subplot(1, 1, 1) -fig = plt.figure() -ax = fig.add_subplot(1, 1, 1) + ax.axis("off") + ax.set_xlim(-1.6, 1.6) + ax.set_ylim(-1.6, 1.6) + ax.set_aspect(1) -ax.axis('off') -ax.set_xlim(-1.6, 1.6) -ax.set_ylim(-1.6, 1.6) -ax.set_aspect(1) + plot_field = ax.quiver( + origins[:, 0].numpy(), + origins[:, 1].numpy(), + field[:, 0].numpy(), + field[:, 1].numpy(), + units="xy", + scale=1, + width=3e-3, + headwidth=25, + headlength=25, + ) -plot_field = ax.quiver(origins[:, 0].numpy(), origins[:, 1].numpy(), - field[:, 0].numpy(), field[:, 1].numpy(), - units = 'xy', scale = 1, - width = 3e-3, headwidth = 25, headlength = 25) + plot_data = ax.scatter( + data[:, 0].numpy(), data[:, 1].numpy(), s=1, color="tab:blue" + ) -plot_data = ax.scatter(data[:, 0].numpy(), data[:, 1].numpy(), s = 1, color = 'tab:blue') + filename = f"denoising_field_{data_name}.pdf" + print(f"Saving {filename}") + fig.savefig(filename, bbox_inches="tight") -fig.savefig('denoising_field.pdf', bbox_inches='tight') ###################################################################### + +for data_source in [data_rectangle, data_zigzag, data_spiral, data_penta]: + data, data_name = data_source(1000) + data = data - data.mean(0) + model = train_model(data) + save_image(data_name, model, data)