3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
10 import torch, torchvision
13 from torch.nn import functional as F
15 ######################################################################
# Synthetic 1-D inputs. Each x is first drawn uniformly from [0, delta).
# With probability 0.5 (Bernoulli draw) it is then shifted by (1 - delta).
# Samples therefore land in [0, delta) or [1 - delta, 1), which leaves an
# empty gap in the middle of [0, 1] whenever delta < 0.5.
# NOTE(review): `nb`, `delta` and the `math` import live on lines not visible
# in this excerpt — confirm their values.
19 x = torch.empty(nb).uniform_(0.0, delta)
20 x += x.new_full(x.size(), 0.5).bernoulli() * (1 - delta)
# Phases with 4 and 3 periods (2*pi*4 and 2*pi*3) per unit of x.
# Presumably combined into the targets `y` on lines not visible here — TODO confirm.
22 a = x * math.pi * 2 * 4
23 b = x * math.pi * 2 * 3
29 ######################################################################
# Model: an MLP whose first layer takes a single input feature (nn.Linear(1, nh)).
# `nh` is the hidden width and is defined on a line not visible here.
# NOTE(review): the remainder of this nn.Sequential(...) call (original lines
# 34 and 36-39) is not visible — confirm the output layer and any extra modules.
33 model = nn.Sequential(nn.Linear(1, nh), nn.ReLU(),
35 nn.Linear(nh, nh), nn.ReLU(),
# Mean-squared-error regression of model(x) onto y, optimized with Adam (lr 1e-4).
40 criterion = nn.MSELoss()
41 optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
# 10,000 iterations. Each one evaluates the loss on the whole of x and reports
# it every 100 steps.
# NOTE(review): zero_grad / backward / step are not visible in this excerpt
# (original lines 46-49) — confirm they are present, otherwise nothing trains.
43 for k in range(10000):
44 loss = criterion(model(x), y)
45 if (k+1)%100 == 0: print(k+1, loss.item())
50 ######################################################################
# Plot the model's prediction over [0, 1] against the training samples.
52 import matplotlib.pyplot as plt
54 fig, ax = plt.subplots()
# 101 evenly spaced evaluation points in [0, 1].
56 u = torch.linspace(0, 1, 101)
# Repeat every grid point 25 times and flatten to a (101*25, 1) column, the
# layout nn.Linear(1, ...) expects. Row-major order keeps the 25 copies of
# u[i] contiguous.
57 v = u.view(-1, 1).expand(-1, 25).reshape(-1, 1)
# Regroup the outputs so row i holds the 25 predictions for u[i]. This
# assumes a scalar model output — TODO confirm against the hidden output layer.
# NOTE(review): identical inputs only give different outputs if the model is
# stochastic (e.g. dropout in the unseen part of the Sequential) — confirm.
58 v = model(v).reshape(101, -1)
# NOTE(review): `mean` and `std` are computed on lines not visible here
# (original 59-61), presumably per row of v — TODO confirm.
# Grey band = mean +/- std. Red line = mean. Dots = training data.
# .detach() drops the autograd graph attached to the model outputs before
# the numpy conversion.
62 ax.fill_between(u.numpy(), (mean-std).detach().numpy(), (mean+std).detach().numpy(), color = '#e0e0e0')
63 ax.plot(u.numpy(), mean.detach().numpy(), color = 'red')
64 ax.scatter(x.numpy(), y.numpy())
68 ######################################################################