X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=denoising-ae-field.py;h=f96c23a1243a003b7c0f8cb038e0f63062ca4b9c;hp=8f748d11a3ff219922220b4ed9f14b3eb473a2e1;hb=HEAD;hpb=72e4bc5d20e153800f19a94e4bfd075adf30e3f3 diff --git a/denoising-ae-field.py b/denoising-ae-field.py index 8f748d1..3ef0c80 100755 --- a/denoising-ae-field.py +++ b/denoising-ae-field.py @@ -1,5 +1,10 @@ #!/usr/bin/env python +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + import math import matplotlib.pyplot as plt @@ -8,21 +13,35 @@ from torch import nn ###################################################################### + +def data_rectangle(nb): + x = torch.rand(nb, 1) - 0.5 + y = torch.rand(nb, 1) * 2 - 1 + data = torch.cat((y, x), 1) + alpha = math.pi / 8 + data = data @ torch.tensor( + [[math.cos(alpha), math.sin(alpha)], [-math.sin(alpha), math.cos(alpha)]] + ) + return data, "rectangle" + + def data_zigzag(nb): a = torch.empty(nb).uniform_(0, 1).view(-1, 1) # zigzag - x = 0.4 * ((a-0.5) * 5 * math.pi).cos() + x = 0.4 * ((a - 0.5) * 5 * math.pi).cos() y = a * 2.5 - 1.25 data = torch.cat((y, x), 1) - data = data @ torch.tensor([[1., -1.], [1., 1.]]) - return data, 'zigzag' + data = data @ torch.tensor([[1.0, -1.0], [1.0, 1.0]]) + return data, "zigzag" + def data_spiral(nb): a = torch.empty(nb).uniform_(0, 1).view(-1, 1) x = (a * 2.25 * math.pi).cos() * (a * 0.8 + 0.5) y = (a * 2.25 * math.pi).sin() * (a * 0.8 + 0.5) data = torch.cat((y, x), 1) - return data, 'spiral' + return data, "spiral" + def data_penta(nb): a = (torch.randint(5, (nb,)).float() / 5 * 2 * math.pi).view(-1, 1) @@ -30,19 +49,17 @@ def data_penta(nb): y = a.sin() data = torch.cat((y, x), 1) data = data + data.new(data.size()).normal_(0, 0.05) - return data, 'penta' + return data, "penta" + ###################################################################### + def train_model(data): - model = nn.Sequential( - nn.Linear(2, 100), - nn.ReLU(), - nn.Linear(100, 2) - ) + model = nn.Sequential(nn.Linear(2, 100), nn.ReLU(), nn.Linear(100, 2)) batch_size, nb_epochs = 100, 1000 - optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) criterion = nn.MSELoss() for e in range(nb_epochs): @@ -55,16 +72,19 @@ def train_model(data): optimizer.zero_grad() loss.backward() optimizer.step() - if (e+1)%100 == 0: print(e+1, acc_loss) + if (e + 1) % 100 == 0: + print(e + 1, acc_loss) return model + ###################################################################### -def save_image(data, data_name, model): + +def save_image(data_name, model, data): a = torch.linspace(-1.5, 1.5, 30) - x = a.view( 1, -1, 1).expand(a.size(0), a.size(0), 1) - y = a.view(-1, 1, 1).expand(a.size(0), a.size(0), 1) + x = a.view(1, -1, 1).expand(a.size(0), a.size(0), 1) + y = a.view(-1, 1, 1).expand(a.size(0), a.size(0), 1) grid = torch.cat((y, x), 2).view(-1, 2) # Take the origins of the arrows on the part of the grid closer than @@ -77,31 +97,36 @@ def save_image(data, data_name, model): fig = plt.figure() ax = fig.add_subplot(1, 1, 1) - ax.axis('off') + ax.axis("off") ax.set_xlim(-1.6, 1.6) ax.set_ylim(-1.6, 1.6) ax.set_aspect(1) plot_field = ax.quiver( - origins[:, 0].numpy(), origins[:, 1].numpy(), - field[:, 0].numpy(), field[:, 1].numpy(), - units = 'xy', scale = 1, - width = 3e-3, headwidth = 25, headlength = 25 + origins[:, 0].numpy(), + origins[:, 1].numpy(), + field[:, 0].numpy(), + field[:, 1].numpy(), + units="xy", + scale=1, + width=3e-3, + headwidth=25, + headlength=25, ) plot_data = ax.scatter( - data[:, 0].numpy(), data[:, 1].numpy(), - s = 1, color = 'tab:blue' + data[:, 0].numpy(), data[:, 1].numpy(), s=1, color="tab:blue" ) - filename = f'denoising_field_{data_name}.pdf' - print(f'Saving {filename}') - fig.savefig(filename, bbox_inches='tight') + filename = f"denoising_field_{data_name}.pdf" + print(f"Saving {filename}") + fig.savefig(filename, bbox_inches="tight") + ###################################################################### -for data_source in [ data_zigzag, data_spiral, data_penta ]: +for data_source in [data_rectangle, data_zigzag, data_spiral, data_penta]: data, data_name = data_source(1000) data = data - data.mean(0) model = train_model(data) - save_image(data, data_name, model) + save_image(data_name, model, data)