Update.
authorFrancois Fleuret <francois.fleuret@idiap.ch>
Fri, 4 Jan 2019 14:33:06 +0000 (15:33 +0100)
committerFrancois Fleuret <francois.fleuret@idiap.ch>
Fri, 4 Jan 2019 14:33:06 +0000 (15:33 +0100)
confidence.py
mine_mnist.py

index ff4b395..1be3420 100755 (executable)
@@ -23,14 +23,17 @@ y = y.view(-1, 1)
 
 ######################################################################
 
-nh = 100
+nh = 400
 
 model = nn.Sequential(nn.Linear(1, nh), nn.ReLU(),
+                      nn.Dropout(0.25),
                       nn.Linear(nh, nh), nn.ReLU(),
+                      nn.Dropout(0.25),
                       nn.Linear(nh, 1))
 
+model.train(True)
 criterion = nn.MSELoss()
-optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
+optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
 
 for k in range(10000):
     loss = criterion(model(x), y)
@@ -44,10 +47,17 @@ for k in range(10000):
 import matplotlib.pyplot as plt
 
 fig, ax = plt.subplots()
+
+u = torch.linspace(0, 1, 101)
+v = u.view(-1, 1).expand(-1, 25).reshape(-1, 1)
+v = model(v).reshape(101, -1)
+mean = v.mean(1)
+std = v.std(1)
+
+ax.fill_between(u.numpy(), (mean-std).detach().numpy(), (mean+std).detach().numpy(), color = '#e0e0e0')
+ax.plot(u.numpy(), mean.detach().numpy(), color = 'red')
 ax.scatter(x.numpy(), y.numpy())
 
-u = torch.linspace(0, 1, 100).view(-1, 1)
-ax.plot(u.numpy(), model(u).detach().numpy(), color = 'red')
 plt.show()
 
 ######################################################################
index 389544b..1d69640 100755 (executable)
@@ -1,5 +1,22 @@
 #!/usr/bin/env python
 
+#########################################################################
+# This program is free software: you can redistribute it and/or modify  #
+# it under the terms of the version 3 of the GNU General Public License #
+# as published by the Free Software Foundation.                         #
+#                                                                       #
+# This program is distributed in the hope that it will be useful, but   #
+# WITHOUT ANY WARRANTY; without even the implied warranty of            #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      #
+# General Public License for more details.                              #
+#                                                                       #
+# You should have received a copy of the GNU General Public License     #
+# along with this program. If not, see <http://www.gnu.org/licenses/>.  #
+#                                                                       #
+# Written by and Copyright (C) Francois Fleuret                         #
+# Contact <francois.fleuret@idiap.ch> for comments & bug reports        #
+#########################################################################
+
 import argparse, math, sys
 from copy import deepcopy
 
@@ -19,13 +36,28 @@ else:
 ######################################################################
 
 parser = argparse.ArgumentParser(
-    description = 'An implementation of Mutual Information estimator with a deep model',
+    description = '''An implementation of a Mutual Information estimator with a deep model
+
+Three different toy data-sets are implemented:
+
+ (1) Two MNIST images of same class. The "true" MI is the log of the
+     number of used MNIST classes.
+
+ (2) One MNIST image and a pair of real numbers whose difference is
+     the class of the image. The "true" MI is the log of the number of
+     used MNIST classes.
+
+ (3) Two 1d sequences, the first with a single peak, the second with
+     two peaks, and the height of the peak in the first is the
+     difference of timing of the peaks in the second. The "true" MI is
+     the log of the number of possible peak heights.''',
+
     formatter_class = argparse.ArgumentDefaultsHelpFormatter
 )
 
 parser.add_argument('--data',
                     type = str, default = 'image_pair',
-                    help = 'What data')
+                    help = 'What data: image_pair, image_values_pair, sequence_pair')
 
 parser.add_argument('--seed',
                     type = int, default = 0,
@@ -47,9 +79,23 @@ parser.add_argument('--batch_size',
                     type = int, default = 100,
                     help = 'Batch size')
 
+parser.add_argument('--learning_rate',
+                    type = float, default = 1e-3,
+                    help = 'Batch size')
+
 parser.add_argument('--independent', action = 'store_true',
                     help = 'Should the pair components be independent')
 
+
+######################################################################
+
+args = parser.parse_args()
+
+if args.seed >= 0:
+    torch.manual_seed(args.seed)
+
+used_MNIST_classes = torch.tensor(eval('[' + args.mnist_classes + ']'), device = device)
+
 ######################################################################
 
 def entropy(target):
@@ -61,21 +107,6 @@ def entropy(target):
     probas /= probas.sum()
     return - (probas * probas.log()).sum().item()
 
-def robust_log_mean_exp(x):
-    # a = x.max()
-    # return (x-a).exp().mean().log() + a
-    # a = x.max()
-    return x.exp().mean().log()
-
-######################################################################
-
-args = parser.parse_args()
-
-if args.seed >= 0:
-    torch.manual_seed(args.seed)
-
-used_MNIST_classes = torch.tensor(eval('[' + args.mnist_classes + ']'), device = device)
-
 ######################################################################
 
 train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
@@ -129,7 +160,7 @@ def create_image_pairs(train = False):
 ######################################################################
 
 # Returns a triplet a, b, c where a are the standard MNIST images, c
-# the classes, and b is a Nx2 tensor, eith for every n:
+# the classes, and b is a Nx2 tensor, with for every n:
 #
 #   b[n, 0] ~ Uniform(0, 10)
 #   b[n, 1] ~ b[n, 0] + Uniform(0, 0.5) + c[n]
@@ -158,7 +189,8 @@ def create_image_values_pairs(train = False):
     b[:, 1].uniform_(0.0, 0.5)
 
     if args.independent:
-        b[:, 1] += b[:, 0] + used_MNIST_classes[torch.randint(len(used_MNIST_classes), target.size())]
+        b[:, 1] += b[:, 0] + \
+                   used_MNIST_classes[torch.randint(len(used_MNIST_classes), target.size())]
     else:
         b[:, 1] += b[:, 0] + target.float()
 
@@ -189,6 +221,7 @@ def create_sequences_pairs(train = False):
     b1 = b1 - pos.view(nb, 1)
     b1 = (b1 >= 0).float() * torch.exp(-b1 * math.log(2) / 0.1) * 0.25
     pos = pos + hb.float() / (args.nb_classes + 1) * 0.5
+    # pos += pos.new(hb.size()).uniform_(0.0, 0.01)
     b2 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
     b2 = b2 - pos.view(nb, 1)
     b2 = (b2 >= 0).float() * torch.exp(-b2 * math.log(2) / 0.1) * 0.25
@@ -317,40 +350,43 @@ class NetForSequencePair(nn.Module):
 if args.data == 'image_pair':
     create_pairs = create_image_pairs
     model = NetForImagePair()
+
 elif args.data == 'image_values_pair':
     create_pairs = create_image_values_pairs
     model = NetForImageValuesPair()
+
 elif args.data == 'sequence_pair':
     create_pairs = create_sequences_pairs
     model = NetForSequencePair()
-    ######################################################################
+
+    ## Save for figures
     a, b, c = create_pairs()
     for k in range(10):
         file = open(f'train_{k:02d}.dat', 'w')
         for i in range(a.size(1)):
             file.write(f'{a[k, i]:f} {b[k,i]:f}\n')
         file.close()
-    # exit(0)
-    ######################################################################
+
 else:
     raise Exception('Unknown data ' + args.data)
 
 ######################################################################
+# Train
 
-print('nb_parameters %d' % sum(x.numel() for x in model.parameters()))
+print(f'nb_parameters {sum(x.numel() for x in model.parameters())}')
 
 model.to(device)
 
+input_a, input_b, classes = create_pairs(train = True)
+
 for e in range(args.nb_epochs):
 
-    input_a, input_b, classes = create_pairs(train = True)
+    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
 
     input_br = input_b[torch.randperm(input_b.size(0))]
 
     acc_mi = 0.0
 
-    optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
-
     for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
                                           input_b.split(args.batch_size),
                                           input_br.split(args.batch_size)):
@@ -363,11 +399,12 @@ for e in range(args.nb_epochs):
 
     acc_mi /= (input_a.size(0) // args.batch_size)
 
-    print('%d %.04f %.04f' % (e + 1, acc_mi / math.log(2), entropy(classes) / math.log(2)))
+    print(f'{e+1} {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}')
 
     sys.stdout.flush()
 
 ######################################################################
+# Test
 
 input_a, input_b, classes = create_pairs(train = False)
 
@@ -383,6 +420,6 @@ for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
 
 acc_mi /= (input_a.size(0) // args.batch_size)
 
-print('test %.04f %.04f'%(acc_mi / math.log(2), entropy(classes) / math.log(2)))
+print(f'test {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}')
 
 ######################################################################