Added the rectangle.
authorFrancois Fleuret <francois@fleuret.org>
Mon, 23 Dec 2019 16:39:26 +0000 (17:39 +0100)
committerFrancois Fleuret <francois@fleuret.org>
Mon, 23 Dec 2019 16:39:26 +0000 (17:39 +0100)
denoising-ae-field.py

index 47e6ab4..f96c23a 100755 (executable)
@@ -13,6 +13,19 @@ from torch import nn
 
 ######################################################################
 
+def data_rectangle(nb):
+    x = torch.rand(nb, 1) - 0.5
+    y = torch.rand(nb, 1) * 2 - 1
+    data = torch.cat((y, x), 1)
+    alpha = math.pi / 8
+    data = data @ torch.tensor(
+        [
+            [ math.cos(alpha), math.sin(alpha)],
+            [-math.sin(alpha), math.cos(alpha)]
+        ]
+    )
+    return data, 'rectangle'
+
 def data_zigzag(nb):
     a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
     # zigzag
@@ -105,7 +118,7 @@ def save_image(data_name, model, data):
 
 ######################################################################
 
-for data_source in [ data_zigzag, data_spiral, data_penta ]:
+for data_source in [ data_rectangle, data_zigzag, data_spiral, data_penta ]:
     data, data_name = data_source(1000)
     data = data - data.mean(0)
     model = train_model(data)