X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=denoising-ae-field.py;h=f96c23a1243a003b7c0f8cb038e0f63062ca4b9c;hp=8f748d11a3ff219922220b4ed9f14b3eb473a2e1;hb=7315192ca0a1d1fdbfcb85da97ee41b2c68cbc6e;hpb=72e4bc5d20e153800f19a94e4bfd075adf30e3f3 diff --git a/denoising-ae-field.py b/denoising-ae-field.py index 8f748d1..f96c23a 100755 --- a/denoising-ae-field.py +++ b/denoising-ae-field.py @@ -1,5 +1,10 @@ #!/usr/bin/env python +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + import math import matplotlib.pyplot as plt @@ -8,6 +13,19 @@ from torch import nn ###################################################################### +def data_rectangle(nb): + x = torch.rand(nb, 1) - 0.5 + y = torch.rand(nb, 1) * 2 - 1 + data = torch.cat((y, x), 1) + alpha = math.pi / 8 + data = data @ torch.tensor( + [ + [ math.cos(alpha), math.sin(alpha)], + [-math.sin(alpha), math.cos(alpha)] + ] + ) + return data, 'rectangle' + def data_zigzag(nb): a = torch.empty(nb).uniform_(0, 1).view(-1, 1) # zigzag @@ -61,7 +79,7 @@ def train_model(data): ###################################################################### -def save_image(data, data_name, model): +def save_image(data_name, model, data): a = torch.linspace(-1.5, 1.5, 30) x = a.view( 1, -1, 1).expand(a.size(0), a.size(0), 1) y = a.view(-1, 1, 1).expand(a.size(0), a.size(0), 1) @@ -100,8 +118,8 @@ def save_image(data, data_name, model): ###################################################################### -for data_source in [ data_zigzag, data_spiral, data_penta ]: +for data_source in [ data_rectangle, data_zigzag, data_spiral, data_penta ]: data, data_name = data_source(1000) data = data - data.mean(0) model = train_model(data) - save_image(data, data_name, model) + save_image(data_name, model, data)