X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=confidence.py;h=4530fbc12175e6e199394dc237353e5cdfba33e3;hp=ff4b395e8b0b96df1ebf0759cc580f802500fb81;hb=HEAD;hpb=4469498b31c1fb90cb2b1202dbaf86be0f2d18b0 diff --git a/confidence.py b/confidence.py index ff4b395..a586a3d 100755 --- a/confidence.py +++ b/confidence.py @@ -1,5 +1,10 @@ #!/usr/bin/env python +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + import math import torch, torchvision @@ -23,18 +28,26 @@ y = y.view(-1, 1) ###################################################################### -nh = 100 +nh = 400 -model = nn.Sequential(nn.Linear(1, nh), nn.ReLU(), - nn.Linear(nh, nh), nn.ReLU(), - nn.Linear(nh, 1)) +model = nn.Sequential( + nn.Linear(1, nh), + nn.ReLU(), + nn.Dropout(0.25), + nn.Linear(nh, nh), + nn.ReLU(), + nn.Dropout(0.25), + nn.Linear(nh, 1), +) +model.train(True) criterion = nn.MSELoss() -optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3) +optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) for k in range(10000): loss = criterion(model(x), y) - if (k+1)%100 == 0: print(k+1, loss.item()) + if (k + 1) % 100 == 0: + print(k + 1, loss.item()) optimizer.zero_grad() loss.backward() optimizer.step() @@ -44,10 +57,22 @@ for k in range(10000): import matplotlib.pyplot as plt fig, ax = plt.subplots() + +u = torch.linspace(0, 1, 101) +v = u.view(-1, 1).expand(-1, 25).reshape(-1, 1) +v = model(v).reshape(101, -1) +mean = v.mean(1) +std = v.std(1) + +ax.fill_between( + u.numpy(), + (mean - std).detach().numpy(), + (mean + std).detach().numpy(), + color="#e0e0e0", +) +ax.plot(u.numpy(), mean.detach().numpy(), color="red") ax.scatter(x.numpy(), y.numpy()) -u = torch.linspace(0, 1, 100).view(-1, 1) -ax.plot(u.numpy(), model(u).detach().numpy(), color = 'red') plt.show() ######################################################################