X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=confidence.py;h=4530fbc12175e6e199394dc237353e5cdfba33e3;hp=ff4b395e8b0b96df1ebf0759cc580f802500fb81;hb=e916a8624b6a09737696c124f35059030f0f20e4;hpb=4469498b31c1fb90cb2b1202dbaf86be0f2d18b0 diff --git a/confidence.py b/confidence.py index ff4b395..4530fbc 100755 --- a/confidence.py +++ b/confidence.py @@ -1,5 +1,10 @@ #!/usr/bin/env python +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + import math import torch, torchvision @@ -23,14 +28,17 @@ y = y.view(-1, 1) ###################################################################### -nh = 100 +nh = 400 model = nn.Sequential(nn.Linear(1, nh), nn.ReLU(), + nn.Dropout(0.25), nn.Linear(nh, nh), nn.ReLU(), + nn.Dropout(0.25), nn.Linear(nh, 1)) +model.train(True) criterion = nn.MSELoss() -optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3) +optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4) for k in range(10000): loss = criterion(model(x), y) @@ -44,10 +52,17 @@ for k in range(10000): import matplotlib.pyplot as plt fig, ax = plt.subplots() + +u = torch.linspace(0, 1, 101) +v = u.view(-1, 1).expand(-1, 25).reshape(-1, 1) +v = model(v).reshape(101, -1) +mean = v.mean(1) +std = v.std(1) + +ax.fill_between(u.numpy(), (mean-std).detach().numpy(), (mean+std).detach().numpy(), color = '#e0e0e0') +ax.plot(u.numpy(), mean.detach().numpy(), color = 'red') ax.scatter(x.numpy(), y.numpy()) -u = torch.linspace(0, 1, 100).view(-1, 1) -ax.plot(u.numpy(), model(u).detach().numpy(), color = 'red') plt.show() ######################################################################