Update.
[pytorch.git] / confidence.py
index 1be3420..a586a3d 100755 (executable)
@@ -1,5 +1,10 @@
 #!/usr/bin/env python
 
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
 import math
 
 import torch, torchvision
@@ -25,19 +30,24 @@ y = y.view(-1, 1)
 
 nh = 400
 
-model = nn.Sequential(nn.Linear(1, nh), nn.ReLU(),
-                      nn.Dropout(0.25),
-                      nn.Linear(nh, nh), nn.ReLU(),
-                      nn.Dropout(0.25),
-                      nn.Linear(nh, 1))
+model = nn.Sequential(
+    nn.Linear(1, nh),
+    nn.ReLU(),
+    nn.Dropout(0.25),
+    nn.Linear(nh, nh),
+    nn.ReLU(),
+    nn.Dropout(0.25),
+    nn.Linear(nh, 1),
+)
 
 model.train(True)
 criterion = nn.MSELoss()
-optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
+optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
 
 for k in range(10000):
     loss = criterion(model(x), y)
-    if (k+1)%100 == 0: print(k+1, loss.item())
+    if (k + 1) % 100 == 0:
+        print(k + 1, loss.item())
     optimizer.zero_grad()
     loss.backward()
     optimizer.step()
@@ -54,8 +64,13 @@ v = model(v).reshape(101, -1)
 mean = v.mean(1)
 std = v.std(1)
 
-ax.fill_between(u.numpy(), (mean-std).detach().numpy(), (mean+std).detach().numpy(), color = '#e0e0e0')
-ax.plot(u.numpy(), mean.detach().numpy(), color = 'red')
+ax.fill_between(
+    u.numpy(),
+    (mean - std).detach().numpy(),
+    (mean + std).detach().numpy(),
+    color="#e0e0e0",
+)
+ax.plot(u.numpy(), mean.detach().numpy(), color="red")
 ax.scatter(x.numpy(), y.numpy())
 
 plt.show()