+import torch.nn.functional as F
+
+######################################################################
+
+if torch.cuda.is_available():
+ torch.backends.cudnn.benchmark = True
+ device = torch.device('cuda')
+else:
+ device = torch.device('cpu')
+
+######################################################################
+
+parser = argparse.ArgumentParser(
+ description = '''An implementation of a Mutual Information estimator with a deep model
+
+Three different toy data-sets are implemented:
+
+ (1) Two MNIST images of same class. The "true" MI is the log of the
+ number of used MNIST classes.
+
+ (2) One MNIST image and a pair of real numbers whose difference is
+ the class of the image. The "true" MI is the log of the number of
+ used MNIST classes.
+
+ (3) Two 1d sequences, the first with a single peak, the second with
+ two peaks, and the height of the peak in the first is the
+ difference of timing of the peaks in the second. The "true" MI is
+ the log of the number of possible peak heights.''',
+
+ formatter_class = argparse.ArgumentDefaultsHelpFormatter
+)
+
+parser.add_argument('--data',
+ type = str, default = 'image_pair',
+ help = 'What data: image_pair, image_values_pair, sequence_pair')
+
+parser.add_argument('--seed',
+ type = int, default = 0,
+ help = 'Random seed (default 0, < 0 is no seeding)')
+
+parser.add_argument('--mnist_classes',
+ type = str, default = '0, 1, 3, 5, 6, 7, 8, 9',
+ help = 'What MNIST classes to use')
+
+parser.add_argument('--nb_classes',
+ type = int, default = 2,
+ help = 'How many classes for sequences')
+
+parser.add_argument('--nb_epochs',
+ type = int, default = 50,
+ help = 'How many epochs')
+
+parser.add_argument('--batch_size',
+ type = int, default = 100,
+ help = 'Batch size')
+
+parser.add_argument('--learning_rate',
+ type = float, default = 1e-3,
+ help = 'Batch size')
+
+parser.add_argument('--independent', action = 'store_true',
+ help = 'Should the pair components be independent')