projects
/
pytorch.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
aacb2bf
)
Added the rectangle.
author
Francois Fleuret
<francois@fleuret.org>
Mon, 23 Dec 2019 16:39:26 +0000
(17:39 +0100)
committer
Francois Fleuret
<francois@fleuret.org>
Mon, 23 Dec 2019 16:39:26 +0000
(17:39 +0100)
denoising-ae-field.py
patch
|
blob
|
history
diff --git
a/denoising-ae-field.py
b/denoising-ae-field.py
index
47e6ab4
..
f96c23a
100755
(executable)
--- a/
denoising-ae-field.py
+++ b/
denoising-ae-field.py
@@
-13,6
+13,19
@@
from torch import nn
######################################################################
######################################################################
+def data_rectangle(nb):
+ x = torch.rand(nb, 1) - 0.5
+ y = torch.rand(nb, 1) * 2 - 1
+ data = torch.cat((y, x), 1)
+ alpha = math.pi / 8
+ data = data @ torch.tensor(
+ [
+ [ math.cos(alpha), math.sin(alpha)],
+ [-math.sin(alpha), math.cos(alpha)]
+ ]
+ )
+ return data, 'rectangle'
+
def data_zigzag(nb):
a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
# zigzag
def data_zigzag(nb):
a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
# zigzag
@@
-105,7
+118,7
@@
def save_image(data_name, model, data):
######################################################################
######################################################################
-for data_source in [ data_zigzag, data_spiral, data_penta ]:
+for data_source in [ data_
rectangle, data_
zigzag, data_spiral, data_penta ]:
data, data_name = data_source(1000)
data = data - data.mean(0)
model = train_model(data)
data, data_name = data_source(1000)
data = data - data.mean(0)
model = train_model(data)