Update.
authorFrancois Fleuret <francois.fleuret@idiap.ch>
Tue, 18 Dec 2018 08:55:05 +0000 (09:55 +0100)
committerFrancois Fleuret <francois.fleuret@idiap.ch>
Tue, 18 Dec 2018 08:55:05 +0000 (09:55 +0100)
confidence.py [new file with mode: 0755]
mine_mnist.py

diff --git a/confidence.py b/confidence.py
new file mode 100755 (executable)
index 0000000..ff4b395
--- /dev/null
@@ -0,0 +1,53 @@
+#!/usr/bin/env python
+
+import math
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+nb = 100
+delta = 0.35
+x = torch.empty(nb).uniform_(0.0, delta)
+x += x.new_full(x.size(), 0.5).bernoulli() * (1 - delta)
+
+a = x * math.pi * 2 * 4
+b = x * math.pi * 2 * 3
+y = a.sin() + b
+
+x = x.view(-1, 1)
+y = y.view(-1, 1)
+
+######################################################################
+
+nh = 100
+
+model = nn.Sequential(nn.Linear(1, nh), nn.ReLU(),
+                      nn.Linear(nh, nh), nn.ReLU(),
+                      nn.Linear(nh, 1))
+
+criterion = nn.MSELoss()
+optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
+
+for k in range(10000):
+    loss = criterion(model(x), y)
+    if (k+1)%100 == 0: print(k+1, loss.item())
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+
+######################################################################
+
+import matplotlib.pyplot as plt
+
+fig, ax = plt.subplots()
+ax.scatter(x.numpy(), y.numpy())
+
+u = torch.linspace(0, 1, 100).view(-1, 1)
+ax.plot(u.numpy(), model(u).detach().numpy(), color = 'red')
+plt.show()
+
+######################################################################
index 0c485b2..389544b 100755 (executable)
@@ -47,6 +47,9 @@ parser.add_argument('--batch_size',
                     type = int, default = 100,
                     help = 'Batch size')
 
+parser.add_argument('--independent', action = 'store_true',
+                    help = 'Should the pair components be independent')
+
 ######################################################################
 
 def entropy(target):
@@ -116,6 +119,9 @@ def create_image_pairs(train = False):
     c = torch.cat(uc, 0)
     perm = torch.randperm(a.size(0))
     a = a[perm].contiguous()
+
+    if args.independent:
+        perm = torch.randperm(a.size(0))
     b = b[perm].contiguous()
 
     return a, b, c
@@ -150,7 +156,11 @@ def create_image_values_pairs(train = False):
     b = a.new(a.size(0), 2)
     b[:, 0].uniform_(0.0, 10.0)
     b[:, 1].uniform_(0.0, 0.5)
-    b[:, 1] += b[:, 0] + target.float()
+
+    if args.independent:
+        b[:, 1] += b[:, 0] + used_MNIST_classes[torch.randint(len(used_MNIST_classes), target.size())]
+    else:
+        b[:, 1] += b[:, 0] + target.float()
 
     return a, b, c
 
@@ -161,8 +171,10 @@ def create_sequences_pairs(train = False):
     noise_level = 2e-2
 
     ha = torch.randint(args.nb_classes, (nb, ), device = device) + 1
-    # hb = torch.randint(args.nb_classes, (nb, ), device = device)
-    hb = ha
+    if args.independent:
+        hb = torch.randint(args.nb_classes, (nb, ), device = device)
+    else:
+        hb = ha
 
     pos = torch.empty(nb, device = device).uniform_(0.0, 0.9)
     a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
@@ -314,7 +326,7 @@ elif args.data == 'sequence_pair':
     ######################################################################
     a, b, c = create_pairs()
     for k in range(10):
-        file = open(f'/tmp/train_{k:02d}.dat', 'w')
+        file = open(f'train_{k:02d}.dat', 'w')
         for i in range(a.size(1)):
             file.write(f'{a[k, i]:f} {b[k,i]:f}\n')
         file.close()