Update.
[pytorch.git] / confidence.py
index ff4b395..1be3420 100755 (executable)
@@ -23,14 +23,17 @@ y = y.view(-1, 1)
 
 ######################################################################
 
-nh = 100
+nh = 400
 
 model = nn.Sequential(nn.Linear(1, nh), nn.ReLU(),
+                      nn.Dropout(0.25),
                       nn.Linear(nh, nh), nn.ReLU(),
+                      nn.Dropout(0.25),
                       nn.Linear(nh, 1))
 
+model.train(True)
 criterion = nn.MSELoss()
-optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
+optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
 
 for k in range(10000):
     loss = criterion(model(x), y)
@@ -44,10 +47,17 @@ for k in range(10000):
 import matplotlib.pyplot as plt
 
 fig, ax = plt.subplots()
+
+u = torch.linspace(0, 1, 101)
+v = u.view(-1, 1).expand(-1, 25).reshape(-1, 1)
+v = model(v).reshape(101, -1)
+mean = v.mean(1)
+std = v.std(1)
+
+ax.fill_between(u.numpy(), (mean-std).detach().numpy(), (mean+std).detach().numpy(), color = '#e0e0e0')
+ax.plot(u.numpy(), mean.detach().numpy(), color = 'red')
 ax.scatter(x.numpy(), y.numpy())
 
-u = torch.linspace(0, 1, 100).view(-1, 1)
-ax.plot(u.numpy(), model(u).detach().numpy(), color = 'red')
 plt.show()
 
 ######################################################################