X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=confidence.py;h=1be342039819bb3bd855c13a85e5828fb0dc9b21;hp=ff4b395e8b0b96df1ebf0759cc580f802500fb81;hb=75267f198e8f6cf476cb73d2846653494d7164b6;hpb=4469498b31c1fb90cb2b1202dbaf86be0f2d18b0 diff --git a/confidence.py b/confidence.py index ff4b395..1be3420 100755 --- a/confidence.py +++ b/confidence.py @@ -23,14 +23,17 @@ y = y.view(-1, 1) ###################################################################### -nh = 100 +nh = 400 model = nn.Sequential(nn.Linear(1, nh), nn.ReLU(), + nn.Dropout(0.25), nn.Linear(nh, nh), nn.ReLU(), + nn.Dropout(0.25), nn.Linear(nh, 1)) +model.train(True) criterion = nn.MSELoss() -optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3) +optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4) for k in range(10000): loss = criterion(model(x), y) @@ -44,10 +47,17 @@ for k in range(10000): import matplotlib.pyplot as plt fig, ax = plt.subplots() + +u = torch.linspace(0, 1, 101) +v = u.view(-1, 1).expand(-1, 25).reshape(-1, 1) +v = model(v).reshape(101, -1) +mean = v.mean(1) +std = v.std(1) + +ax.fill_between(u.numpy(), (mean-std).detach().numpy(), (mean+std).detach().numpy(), color = '#e0e0e0') +ax.plot(u.numpy(), mean.detach().numpy(), color = 'red') ax.scatter(x.numpy(), y.numpy()) -u = torch.linspace(0, 1, 100).view(-1, 1) -ax.plot(u.numpy(), model(u).detach().numpy(), color = 'red') plt.show() ######################################################################