Update.
[pytorch.git] / mi_estimator.py
index 02e9db9..1a167fe 100755 (executable)
@@ -1,22 +1,9 @@
 #!/usr/bin/env python
 
-#########################################################################
-# This program is free software: you can redistribute it and/or modify  #
-# it under the terms of the version 3 of the GNU General Public License #
-# as published by the Free Software Foundation.                         #
-#                                                                       #
-# This program is distributed in the hope that it will be useful, but   #
-# WITHOUT ANY WARRANTY; without even the implied warranty of            #
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      #
-# General Public License for more details.                              #
-#                                                                       #
-# You should have received a copy of the GNU General Public License     #
-# along with this program. If not, see <http://www.gnu.org/licenses/>.  #
-#                                                                       #
-# Written by Francois Fleuret, (C) Idiap Research Institute             #
-#                                                                       #
-# Contact <francois.fleuret@idiap.ch> for comments & bug reports        #
-#########################################################################
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
 
 import argparse, math, sys
 from copy import deepcopy
@@ -30,14 +17,14 @@ import torch.nn.functional as F
 
 if torch.cuda.is_available():
     torch.backends.cudnn.benchmark = True
-    device = torch.device('cuda')
+    device = torch.device("cuda")
 else:
-    device = torch.device('cpu')
+    device = torch.device("cpu")
 
 ######################################################################
 
 parser = argparse.ArgumentParser(
-    description = '''An implementation of a Mutual Information estimator with a deep model
+    description="""An implementation of a Mutual Information estimator with a deep model
 
     Three different toy data-sets are implemented, each consists of
     pairs of samples, that may be from different spaces:
@@ -52,41 +39,43 @@ parser = argparse.ArgumentParser(
     (3) Two 1d sequences, the first with a single peak, the second with
     two peaks, and the height of the peak in the first is the
     difference of timing of the peaks in the second. The "true" MI is
-    the log of the number of possible peak heights.''',
-
-    formatter_class = argparse.ArgumentDefaultsHelpFormatter
+    the log of the number of possible peak heights.""",
+    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
 )
 
-parser.add_argument('--data',
-                    type = str, default = 'image_pair',
-                    help = 'What data: image_pair, image_values_pair, sequence_pair')
+parser.add_argument(
+    "--data",
+    type=str,
+    default="image_pair",
+    help="What data: image_pair, image_values_pair, sequence_pair",
+)
 
-parser.add_argument('--seed',
-                    type = int, default = 0,
-                    help = 'Random seed (default 0, < 0 is no seeding)')
+parser.add_argument(
+    "--seed", type=int, default=0, help="Random seed (default 0, < 0 is no seeding)"
+)
 
-parser.add_argument('--mnist_classes',
-                    type = str, default = '0, 1, 3, 5, 6, 7, 8, 9',
-                    help = 'What MNIST classes to use')
+parser.add_argument(
+    "--mnist_classes",
+    type=str,
+    default="0, 1, 3, 5, 6, 7, 8, 9",
+    help="What MNIST classes to use",
+)
 
-parser.add_argument('--nb_classes',
-                    type = int, default = 2,
-                    help = 'How many classes for sequences')
+parser.add_argument(
+    "--nb_classes", type=int, default=2, help="How many classes for sequences"
+)
 
-parser.add_argument('--nb_epochs',
-                    type = int, default = 50,
-                    help = 'How many epochs')
+parser.add_argument("--nb_epochs", type=int, default=50, help="How many epochs")
 
-parser.add_argument('--batch_size',
-                    type = int, default = 100,
-                    help = 'Batch size')
+parser.add_argument("--batch_size", type=int, default=100, help="Batch size")
 
-parser.add_argument('--learning_rate',
-                    type = float, default = 1e-3,
-                    help = 'Batch size')
+parser.add_argument("--learning_rate", type=float, default=1e-3, help="Batch size")
 
-parser.add_argument('--independent', action = 'store_true',
-                    help = 'Should the pair components be independent')
+parser.add_argument(
+    "--independent",
+    action="store_true",
+    help="Should the pair components be independent",
+)
 
 ######################################################################
 
@@ -95,26 +84,29 @@ args = parser.parse_args()
 if args.seed >= 0:
     torch.manual_seed(args.seed)
 
-used_MNIST_classes = torch.tensor(eval('[' + args.mnist_classes + ']'), device = device)
+used_MNIST_classes = torch.tensor(eval("[" + args.mnist_classes + "]"), device=device)
 
 ######################################################################
 
+
 def entropy(target):
     probas = []
     for k in range(target.max() + 1):
         n = (target == k).sum().item()
-        if n > 0: probas.append(n)
+        if n > 0:
+            probas.append(n)
     probas = torch.tensor(probas).float()
     probas /= probas.sum()
-    return - (probas * probas.log()).sum().item()
+    return -(probas * probas.log()).sum().item()
+
 
 ######################################################################
 
-train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
-train_input  = train_set.train_data.view(-1, 1, 28, 28).to(device).float()
+train_set = torchvision.datasets.MNIST("./data/mnist/", train=True, download=True)
+train_input = train_set.train_data.view(-1, 1, 28, 28).to(device).float()
 train_target = train_set.train_labels.to(device)
 
-test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True)
+test_set = torchvision.datasets.MNIST("./data/mnist/", train=False, download=True)
 test_input = test_set.test_data.view(-1, 1, 28, 28).to(device).float()
 test_target = test_set.test_labels.to(device)
 
@@ -128,7 +120,8 @@ test_input.sub_(mu).div_(std)
 # half of the samples, with a[i] and b[i] of same class for any i, and
 # c is a 1d long tensor real classes
 
-def create_image_pairs(train = False):
+
+def create_image_pairs(train=False):
     ua, ub, uc = [], [], []
 
     if train:
@@ -137,11 +130,12 @@ def create_image_pairs(train = False):
         input, target = test_input, test_target
 
     for i in used_MNIST_classes:
-        used_indices = torch.arange(input.size(0), device = target.device)\
-                            .masked_select(target == i.item())
+        used_indices = torch.arange(input.size(0), device=target.device).masked_select(
+            target == i.item()
+        )
         x = input[used_indices]
         x = x[torch.randperm(x.size(0))]
-        hs = x.size(0)//2
+        hs = x.size(0) // 2
         ua.append(x.narrow(0, 0, hs))
         ub.append(x.narrow(0, hs, hs))
         uc.append(target[used_indices])
@@ -158,6 +152,7 @@ def create_image_pairs(train = False):
 
     return a, b, c
 
+
 ######################################################################
 
 # Returns a triplet a, b, c where a are the standard MNIST images, c
@@ -166,7 +161,8 @@ def create_image_pairs(train = False):
 #   b[n, 0] ~ Uniform(0, 10)
 #   b[n, 1] ~ b[n, 0] + Uniform(0, 0.5) + c[n]
 
-def create_image_values_pairs(train = False):
+
+def create_image_values_pairs(train=False):
     ua, ub = [], []
 
     if train:
@@ -174,10 +170,12 @@ def create_image_values_pairs(train = False):
     else:
         input, target = test_input, test_target
 
-    m = torch.zeros(used_MNIST_classes.max() + 1, dtype = torch.uint8, device = target.device)
+    m = torch.zeros(
+        used_MNIST_classes.max() + 1, dtype=torch.uint8, device=target.device
+    )
     m[used_MNIST_classes] = 1
     m = m[target]
-    used_indices = torch.arange(input.size(0), device = target.device).masked_select(m)
+    used_indices = torch.arange(input.size(0), device=target.device).masked_select(m)
 
     input = input[used_indices].contiguous()
     target = target[used_indices].contiguous()
@@ -190,42 +188,46 @@ def create_image_values_pairs(train = False):
     b[:, 1].uniform_(0.0, 0.5)
 
     if args.independent:
-        b[:, 1] += b[:, 0] + \
-                   used_MNIST_classes[torch.randint(len(used_MNIST_classes), target.size())]
+        b[:, 1] += (
+            b[:, 0]
+            + used_MNIST_classes[torch.randint(len(used_MNIST_classes), target.size())]
+        )
     else:
         b[:, 1] += b[:, 0] + target.float()
 
     return a, b, c
 
+
 ######################################################################
 
 #
 
-def create_sequences_pairs(train = False):
+
+def create_sequences_pairs(train=False):
     nb, length = 10000, 1024
     noise_level = 2e-2
 
-    ha = torch.randint(args.nb_classes, (nb, ), device = device) + 1
+    ha = torch.randint(args.nb_classes, (nb,), device=device) + 1
     if args.independent:
-        hb = torch.randint(args.nb_classes, (nb, ), device = device)
+        hb = torch.randint(args.nb_classes, (nb,), device=device)
     else:
         hb = ha
 
-    pos = torch.empty(nb, device = device).uniform_(0.0, 0.9)
-    a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
+    pos = torch.empty(nb, device=device).uniform_(0.0, 0.9)
+    a = torch.linspace(0, 1, length, device=device).view(1, -1).expand(nb, -1)
     a = a - pos.view(nb, 1)
     a = (a >= 0).float() * torch.exp(-a * math.log(2) / 0.1)
     a = a * ha.float().view(-1, 1).expand_as(a) / (1 + args.nb_classes)
     noise = a.new(a.size()).normal_(0, noise_level)
     a = a + noise
 
-    pos = torch.empty(nb, device = device).uniform_(0.0, 0.5)
-    b1 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
+    pos = torch.empty(nb, device=device).uniform_(0.0, 0.5)
+    b1 = torch.linspace(0, 1, length, device=device).view(1, -1).expand(nb, -1)
     b1 = b1 - pos.view(nb, 1)
     b1 = (b1 >= 0).float() * torch.exp(-b1 * math.log(2) / 0.1) * 0.25
     pos = pos + hb.float() / (args.nb_classes + 1) * 0.5
     # pos += pos.new(hb.size()).uniform_(0.0, 0.01)
-    b2 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
+    b2 = torch.linspace(0, 1, length, device=device).view(1, -1).expand(nb, -1)
     b2 = b2 - pos.view(nb, 1)
     b2 = (b2 >= 0).float() * torch.exp(-b2 * math.log(2) / 0.1) * 0.25
 
@@ -235,29 +237,33 @@ def create_sequences_pairs(train = False):
 
     return a, b, ha
 
+
 ######################################################################
 
+
 class NetForImagePair(nn.Module):
     def __init__(self):
-        super(NetForImagePair, self).__init__()
+        super().__init__()
         self.features_a = nn.Sequential(
-            nn.Conv2d(1, 16, kernel_size = 5),
-            nn.MaxPool2d(3), nn.ReLU(),
-            nn.Conv2d(16, 32, kernel_size = 5),
-            nn.MaxPool2d(2), nn.ReLU(),
+            nn.Conv2d(1, 16, kernel_size=5),
+            nn.MaxPool2d(3),
+            nn.ReLU(),
+            nn.Conv2d(16, 32, kernel_size=5),
+            nn.MaxPool2d(2),
+            nn.ReLU(),
         )
 
         self.features_b = nn.Sequential(
-            nn.Conv2d(1, 16, kernel_size = 5),
-            nn.MaxPool2d(3), nn.ReLU(),
-            nn.Conv2d(16, 32, kernel_size = 5),
-            nn.MaxPool2d(2), nn.ReLU(),
+            nn.Conv2d(1, 16, kernel_size=5),
+            nn.MaxPool2d(3),
+            nn.ReLU(),
+            nn.Conv2d(16, 32, kernel_size=5),
+            nn.MaxPool2d(2),
+            nn.ReLU(),
         )
 
         self.fully_connected = nn.Sequential(
-            nn.Linear(256, 200),
-            nn.ReLU(),
-            nn.Linear(200, 1)
+            nn.Linear(256, 200), nn.ReLU(), nn.Linear(200, 1)
         )
 
     def forward(self, a, b):
@@ -266,28 +272,33 @@ class NetForImagePair(nn.Module):
         x = torch.cat((a, b), 1)
         return self.fully_connected(x)
 
+
 ######################################################################
 
+
 class NetForImageValuesPair(nn.Module):
     def __init__(self):
-        super(NetForImageValuesPair, self).__init__()
+        super().__init__()
         self.features_a = nn.Sequential(
-            nn.Conv2d(1, 16, kernel_size = 5),
-            nn.MaxPool2d(3), nn.ReLU(),
-            nn.Conv2d(16, 32, kernel_size = 5),
-            nn.MaxPool2d(2), nn.ReLU(),
+            nn.Conv2d(1, 16, kernel_size=5),
+            nn.MaxPool2d(3),
+            nn.ReLU(),
+            nn.Conv2d(16, 32, kernel_size=5),
+            nn.MaxPool2d(2),
+            nn.ReLU(),
         )
 
         self.features_b = nn.Sequential(
-            nn.Linear(2, 32), nn.ReLU(),
-            nn.Linear(32, 32), nn.ReLU(),
-            nn.Linear(32, 128), nn.ReLU(),
+            nn.Linear(2, 32),
+            nn.ReLU(),
+            nn.Linear(32, 32),
+            nn.ReLU(),
+            nn.Linear(32, 128),
+            nn.ReLU(),
         )
 
         self.fully_connected = nn.Sequential(
-            nn.Linear(256, 200),
-            nn.ReLU(),
-            nn.Linear(200, 1)
+            nn.Linear(256, 200), nn.ReLU(), nn.Linear(200, 1)
         )
 
     def forward(self, a, b):
@@ -296,30 +307,31 @@ class NetForImageValuesPair(nn.Module):
         x = torch.cat((a, b), 1)
         return self.fully_connected(x)
 
+
 ######################################################################
 
-class NetForSequencePair(nn.Module):
 
+class NetForSequencePair(nn.Module):
     def feature_model(self):
         kernel_size = 11
         pooling_size = 4
-        return  nn.Sequential(
-            nn.Conv1d(      1, self.nc, kernel_size = kernel_size),
+        return nn.Sequential(
+            nn.Conv1d(1, self.nc, kernel_size=kernel_size),
             nn.AvgPool1d(pooling_size),
             nn.LeakyReLU(),
-            nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
+            nn.Conv1d(self.nc, self.nc, kernel_size=kernel_size),
             nn.AvgPool1d(pooling_size),
             nn.LeakyReLU(),
-            nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
+            nn.Conv1d(self.nc, self.nc, kernel_size=kernel_size),
             nn.AvgPool1d(pooling_size),
             nn.LeakyReLU(),
-            nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
+            nn.Conv1d(self.nc, self.nc, kernel_size=kernel_size),
             nn.AvgPool1d(pooling_size),
             nn.LeakyReLU(),
         )
 
     def __init__(self):
-        super(NetForSequencePair, self).__init__()
+        super().__init__()
 
         self.nc = 32
         self.nh = 256
@@ -328,9 +340,7 @@ class NetForSequencePair(nn.Module):
         self.features_b = self.feature_model()
 
         self.fully_connected = nn.Sequential(
-            nn.Linear(2 * self.nc, self.nh),
-            nn.ReLU(),
-            nn.Linear(self.nh, 1)
+            nn.Linear(2 * self.nc, self.nh), nn.ReLU(), nn.Linear(self.nh, 1)
         )
 
     def forward(self, a, b):
@@ -345,17 +355,18 @@ class NetForSequencePair(nn.Module):
         x = torch.cat((a.view(a.size(0), -1), b.view(b.size(0), -1)), 1)
         return self.fully_connected(x)
 
+
 ######################################################################
 
-if args.data == 'image_pair':
+if args.data == "image_pair":
     create_pairs = create_image_pairs
     model = NetForImagePair()
 
-elif args.data == 'image_values_pair':
+elif args.data == "image_values_pair":
     create_pairs = create_image_values_pairs
     model = NetForImageValuesPair()
 
-elif args.data == 'sequence_pair':
+elif args.data == "sequence_pair":
     create_pairs = create_sequences_pairs
     model = NetForSequencePair()
 
@@ -363,65 +374,70 @@ elif args.data == 'sequence_pair':
     ## Save for figures
     a, b, c = create_pairs()
     for k in range(10):
-        file = open(f'train_{k:02d}.dat', 'w')
+        file = open(f"train_{k:02d}.dat", "w")
         for i in range(a.size(1)):
-            file.write(f'{a[k, i]:f} {b[k,i]:f}\n')
+            file.write(f"{a[k, i]:f} {b[k,i]:f}\n")
         file.close()
     ######################
 
 else:
-    raise Exception('Unknown data ' + args.data)
+    raise Exception("Unknown data " + args.data)
 
 ######################################################################
 # Train
 
-print(f'nb_parameters {sum(x.numel() for x in model.parameters())}')
+print(f"nb_parameters {sum(x.numel() for x in model.parameters())}")
 
 model.to(device)
 
-input_a, input_b, classes = create_pairs(train = True)
+input_a, input_b, classes = create_pairs(train=True)
 
 for e in range(args.nb_epochs):
-
-    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
 
     input_br = input_b[torch.randperm(input_b.size(0))]
 
     acc_mi = 0.0
 
-    for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
-                                          input_b.split(args.batch_size),
-                                          input_br.split(args.batch_size)):
-        mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
+    for batch_a, batch_b, batch_br in zip(
+        input_a.split(args.batch_size),
+        input_b.split(args.batch_size),
+        input_br.split(args.batch_size),
+    ):
+        mi = (
+            model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
+        )
         acc_mi += mi.item()
-        loss = - mi
+        loss = -mi
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
 
-    acc_mi /= (input_a.size(0) // args.batch_size)
+    acc_mi /= input_a.size(0) // args.batch_size
 
-    print(f'{e+1} {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}')
+    print(f"{e+1} {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}")
 
     sys.stdout.flush()
 
 ######################################################################
 # Test
 
-input_a, input_b, classes = create_pairs(train = False)
+input_a, input_b, classes = create_pairs(train=False)
 
 input_br = input_b[torch.randperm(input_b.size(0))]
 
 acc_mi = 0.0
 
-for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
-                                      input_b.split(args.batch_size),
-                                      input_br.split(args.batch_size)):
+for batch_a, batch_b, batch_br in zip(
+    input_a.split(args.batch_size),
+    input_b.split(args.batch_size),
+    input_br.split(args.batch_size),
+):
     mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
     acc_mi += mi.item()
 
-acc_mi /= (input_a.size(0) // args.batch_size)
+acc_mi /= input_a.size(0) // args.batch_size
 
-print(f'test {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}')
+print(f"test {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}")
 
 ######################################################################