projects
/
pytorch.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[pytorch.git]
/
denoising-ae-field.py
diff --git
a/denoising-ae-field.py
b/denoising-ae-field.py
index
f96c23a
..
3ef0c80
100755
(executable)
--- a/
denoising-ae-field.py
+++ b/
denoising-ae-field.py
@@
-13,34
+13,35
@@
from torch import nn
######################################################################
######################################################################
+
def data_rectangle(nb):
x = torch.rand(nb, 1) - 0.5
y = torch.rand(nb, 1) * 2 - 1
data = torch.cat((y, x), 1)
alpha = math.pi / 8
data = data @ torch.tensor(
def data_rectangle(nb):
x = torch.rand(nb, 1) - 0.5
y = torch.rand(nb, 1) * 2 - 1
data = torch.cat((y, x), 1)
alpha = math.pi / 8
data = data @ torch.tensor(
- [
- [ math.cos(alpha), math.sin(alpha)],
- [-math.sin(alpha), math.cos(alpha)]
- ]
+ [[math.cos(alpha), math.sin(alpha)], [-math.sin(alpha), math.cos(alpha)]]
)
)
- return data, 'rectangle'
+ return data, "rectangle"
+
def data_zigzag(nb):
a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
# zigzag
def data_zigzag(nb):
a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
# zigzag
- x = 0.4 * ((a
-
0.5) * 5 * math.pi).cos()
+ x = 0.4 * ((a
-
0.5) * 5 * math.pi).cos()
y = a * 2.5 - 1.25
data = torch.cat((y, x), 1)
y = a * 2.5 - 1.25
data = torch.cat((y, x), 1)
- data = data @ torch.tensor([[1., -1.], [1., 1.]])
- return data, 'zigzag'
+ data = data @ torch.tensor([[1.0, -1.0], [1.0, 1.0]])
+ return data, "zigzag"
+
def data_spiral(nb):
a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
x = (a * 2.25 * math.pi).cos() * (a * 0.8 + 0.5)
y = (a * 2.25 * math.pi).sin() * (a * 0.8 + 0.5)
data = torch.cat((y, x), 1)
def data_spiral(nb):
a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
x = (a * 2.25 * math.pi).cos() * (a * 0.8 + 0.5)
y = (a * 2.25 * math.pi).sin() * (a * 0.8 + 0.5)
data = torch.cat((y, x), 1)
- return data, 'spiral'
+ return data, "spiral"
+
def data_penta(nb):
a = (torch.randint(5, (nb,)).float() / 5 * 2 * math.pi).view(-1, 1)
def data_penta(nb):
a = (torch.randint(5, (nb,)).float() / 5 * 2 * math.pi).view(-1, 1)
@@
-48,19
+49,17
@@
def data_penta(nb):
y = a.sin()
data = torch.cat((y, x), 1)
data = data + data.new(data.size()).normal_(0, 0.05)
y = a.sin()
data = torch.cat((y, x), 1)
data = data + data.new(data.size()).normal_(0, 0.05)
- return data, 'penta'
+ return data, "penta"
+
######################################################################
######################################################################
+
def train_model(data):
def train_model(data):
- model = nn.Sequential(
- nn.Linear(2, 100),
- nn.ReLU(),
- nn.Linear(100, 2)
- )
+ model = nn.Sequential(nn.Linear(2, 100), nn.ReLU(), nn.Linear(100, 2))
batch_size, nb_epochs = 100, 1000
batch_size, nb_epochs = 100, 1000
- optimizer = torch.optim.Adam(model.parameters(), lr
=
1e-3)
+ optimizer = torch.optim.Adam(model.parameters(), lr
=
1e-3)
criterion = nn.MSELoss()
for e in range(nb_epochs):
criterion = nn.MSELoss()
for e in range(nb_epochs):
@@
-73,16
+72,19
@@
def train_model(data):
optimizer.zero_grad()
loss.backward()
optimizer.step()
optimizer.zero_grad()
loss.backward()
optimizer.step()
- if (e+1)%100 == 0: print(e+1, acc_loss)
+ if (e + 1) % 100 == 0:
+ print(e + 1, acc_loss)
return model
return model
+
######################################################################
######################################################################
+
def save_image(data_name, model, data):
a = torch.linspace(-1.5, 1.5, 30)
def save_image(data_name, model, data):
a = torch.linspace(-1.5, 1.5, 30)
- x = a.view(
1, -1, 1).expand(a.size(0), a.size(0), 1)
- y = a.view(-1,
1, 1).expand(a.size(0), a.size(0), 1)
+ x = a.view(1, -1, 1).expand(a.size(0), a.size(0), 1)
+ y = a.view(-1, 1, 1).expand(a.size(0), a.size(0), 1)
grid = torch.cat((y, x), 2).view(-1, 2)
# Take the origins of the arrows on the part of the grid closer than
grid = torch.cat((y, x), 2).view(-1, 2)
# Take the origins of the arrows on the part of the grid closer than
@@
-95,30
+97,35
@@
def save_image(data_name, model, data):
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
- ax.axis(
'off'
)
+ ax.axis(
"off"
)
ax.set_xlim(-1.6, 1.6)
ax.set_ylim(-1.6, 1.6)
ax.set_aspect(1)
plot_field = ax.quiver(
ax.set_xlim(-1.6, 1.6)
ax.set_ylim(-1.6, 1.6)
ax.set_aspect(1)
plot_field = ax.quiver(
- origins[:, 0].numpy(), origins[:, 1].numpy(),
- field[:, 0].numpy(), field[:, 1].numpy(),
- units = 'xy', scale = 1,
- width = 3e-3, headwidth = 25, headlength = 25
+ origins[:, 0].numpy(),
+ origins[:, 1].numpy(),
+ field[:, 0].numpy(),
+ field[:, 1].numpy(),
+ units="xy",
+ scale=1,
+ width=3e-3,
+ headwidth=25,
+ headlength=25,
)
plot_data = ax.scatter(
)
plot_data = ax.scatter(
- data[:, 0].numpy(), data[:, 1].numpy(),
- s = 1, color = 'tab:blue'
+ data[:, 0].numpy(), data[:, 1].numpy(), s=1, color="tab:blue"
)
)
- filename = f'denoising_field_{data_name}.pdf'
- print(f'Saving {filename}')
- fig.savefig(filename, bbox_inches='tight')
+ filename = f"denoising_field_{data_name}.pdf"
+ print(f"Saving {filename}")
+ fig.savefig(filename, bbox_inches="tight")
+
######################################################################
######################################################################
-for data_source in [
data_rectangle, data_zigzag, data_spiral, data_penta
]:
+for data_source in [
data_rectangle, data_zigzag, data_spiral, data_penta
]:
data, data_name = data_source(1000)
data = data - data.mean(0)
model = train_model(data)
data, data_name = data_source(1000)
data = data - data.mean(0)
model = train_model(data)