X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=confidence.py;h=4530fbc12175e6e199394dc237353e5cdfba33e3;hp=1be342039819bb3bd855c13a85e5828fb0dc9b21;hb=HEAD;hpb=75267f198e8f6cf476cb73d2846653494d7164b6 diff --git a/confidence.py b/confidence.py index 1be3420..a586a3d 100755 --- a/confidence.py +++ b/confidence.py @@ -1,5 +1,10 @@ #!/usr/bin/env python +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + import math import torch, torchvision @@ -25,19 +30,24 @@ y = y.view(-1, 1) nh = 400 -model = nn.Sequential(nn.Linear(1, nh), nn.ReLU(), - nn.Dropout(0.25), - nn.Linear(nh, nh), nn.ReLU(), - nn.Dropout(0.25), - nn.Linear(nh, 1)) +model = nn.Sequential( + nn.Linear(1, nh), + nn.ReLU(), + nn.Dropout(0.25), + nn.Linear(nh, nh), + nn.ReLU(), + nn.Dropout(0.25), + nn.Linear(nh, 1), +) model.train(True) criterion = nn.MSELoss() -optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4) +optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) for k in range(10000): loss = criterion(model(x), y) - if (k+1)%100 == 0: print(k+1, loss.item()) + if (k + 1) % 100 == 0: + print(k + 1, loss.item()) optimizer.zero_grad() loss.backward() optimizer.step() @@ -54,8 +64,13 @@ v = model(v).reshape(101, -1) mean = v.mean(1) std = v.std(1) -ax.fill_between(u.numpy(), (mean-std).detach().numpy(), (mean+std).detach().numpy(), color = '#e0e0e0') -ax.plot(u.numpy(), mean.detach().numpy(), color = 'red') +ax.fill_between( + u.numpy(), + (mean - std).detach().numpy(), + (mean + std).detach().numpy(), + color="#e0e0e0", +) +ax.plot(u.numpy(), mean.detach().numpy(), color="red") ax.scatter(x.numpy(), y.numpy()) plt.show()