3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import argparse, math, sys
9 from copy import deepcopy
11 import torch, torchvision
14 import torch.nn.functional as F
16 ######################################################################
# Run on the GPU when one is available, otherwise fall back to the CPU.
# Without the `else`, the CPU assignment would always overwrite the
# CUDA device selected just above.
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
24 ######################################################################
# Command-line interface of the script. ArgumentDefaultsHelpFormatter
# appends the default value of every option to its help text.
parser = argparse.ArgumentParser(
    description = '''An implementation of a Mutual Information estimator with a deep model

    Three different toy data-sets are implemented, each consists of
    pairs of samples, that may be from different spaces:

    (1) Two MNIST images of same class. The "true" MI is the log of the
    number of used MNIST classes.

    (2) One MNIST image and a pair of real numbers whose difference is
    the class of the image. The "true" MI is the log of the number of
    used MNIST classes.

    (3) Two 1d sequences, the first with a single peak, the second with
    two peaks, and the height of the peak in the first is the
    difference of timing of the peaks in the second. The "true" MI is
    the log of the number of possible peak heights.''',

    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--data',
                    type = str, default = 'image_pair',
                    help = 'What data: image_pair, image_values_pair, sequence_pair')

parser.add_argument('--seed',
                    type = int, default = 0,
                    help = 'Random seed (default 0, < 0 is no seeding)')

# Kept as a comma-separated string; it is converted to a tensor of
# class indices after parsing.
parser.add_argument('--mnist_classes',
                    type = str, default = '0, 1, 3, 5, 6, 7, 8, 9',
                    help = 'What MNIST classes to use')

parser.add_argument('--nb_classes',
                    type = int, default = 2,
                    help = 'How many classes for sequences')

parser.add_argument('--nb_epochs',
                    type = int, default = 50,
                    help = 'How many epochs')

parser.add_argument('--batch_size',
                    type = int, default = 100,
                    help = 'Batch size')

parser.add_argument('--learning_rate',
                    type = float, default = 1e-3,
                    help = 'Learning rate')

parser.add_argument('--independent', action = 'store_true',
                    help = 'Should the pair components be independent')
78 ######################################################################
args = parser.parse_args()

# A negative seed means "no seeding" (see the --seed help text), so the
# generator is only seeded for non-negative values.
if args.seed >= 0:
    torch.manual_seed(args.seed)

# The class list is parsed explicitly as comma-separated integers
# instead of eval()-ing a string that comes from the command line.
# Empty items (e.g. a trailing comma) are ignored.
used_MNIST_classes = torch.tensor(
    [int(s) for s in args.mnist_classes.split(',') if s.strip()],
    device = device
)
87 ######################################################################
def entropy(target):
    """Return the Shannon entropy, in nats, of the empirical label distribution.

    Args:
        target: integer tensor of class labels, of any shape.

    Returns:
        A Python float equal to -sum_k p_k log(p_k), where p_k is the
        fraction of the entries of `target` equal to the label k. Labels
        that do not occur contribute nothing. An empty tensor yields 0.0.
    """
    if target.numel() == 0:
        return 0.0

    # unique() counts every label that occurs in a single pass, instead
    # of one comparison and one .item() synchronization per candidate
    # class index.
    _, counts = torch.unique(target, return_counts = True)
    probas = counts.float()
    probas /= probas.sum()
    return - (probas * probas.log()).sum().item()
98 ######################################################################
# MNIST is downloaded on first use. Images are moved to `device`,
# reshaped to N x 1 x 28 x 28 and converted to float; labels are moved
# to `device`. The `.data` / `.targets` attributes are used in place of
# the deprecated `train_data` / `train_labels` / `test_data` /
# `test_labels` accessors.
train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
train_input = train_set.data.view(-1, 1, 28, 28).to(device).float()
train_target = train_set.targets.to(device)

test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True)
test_input = test_set.data.view(-1, 1, 28, 28).to(device).float()
test_target = test_set.targets.to(device)

# Both splits are standardized in place with the mean and standard
# deviation of the training images.
mu, std = train_input.mean(), train_input.std()
train_input.sub_(mu).div_(std)
test_input.sub_(mu).div_(std)
112 ######################################################################
114 # Returns a triplet of tensors (a, b, c), where a and b contain each
115 # half of the samples, with a[i] and b[i] of same class for any i, and
116 # c is a 1d long tensor holding the real classes
def create_image_pairs(train = False):
    """Build pairs of MNIST images that share the same class.

    For every class listed in `used_MNIST_classes`, the images of that
    class are shuffled and split into two halves; the first halves are
    concatenated into `a` and the second halves into `b`, so that a[i]
    and b[i] belong to the same class. The rows are then shuffled.

    Args:
        train: if True the pairs are built from the training split,
            otherwise from the test split.

    Returns:
        A triplet (a, b, c). `a` and `b` are image tensors with the same
        number of rows. `c` is a 1d long tensor with the labels of all
        the samples of the used classes; it is not aligned row-by-row
        with `a` and `b`.

        With `args.independent` set, `b` is shuffled with its own
        permutation, which breaks the class correspondence between a[i]
        and b[i].
    """
    if train:
        images, labels = train_input, train_target
    else:
        images, labels = test_input, test_target

    ua, ub, uc = [], [], []

    # The index vector does not depend on the class, so it is built once.
    all_indices = torch.arange(images.size(0), device = labels.device)

    for i in used_MNIST_classes:
        used_indices = all_indices.masked_select(labels == i.item())
        x = images[used_indices]
        x = x[torch.randperm(x.size(0))]
        # Half size, rounded down: with an odd count one sample is dropped.
        hs = x.size(0) // 2
        ua.append(x.narrow(0, 0, hs))
        ub.append(x.narrow(0, hs, hs))
        uc.append(labels[used_indices])

    a = torch.cat(ua, 0)
    b = torch.cat(ub, 0)
    c = torch.cat(uc, 0)

    perm = torch.randperm(a.size(0))
    a = a[perm].contiguous()

    # Re-using `perm` keeps a[i] and b[i] paired; drawing a new one
    # makes the two components independent.
    if args.independent:
        perm = torch.randperm(a.size(0))
    b = b[perm].contiguous()

    return a, b, c
148 ######################################################################
150 # Returns a triplet a, b, c where a are the standard MNIST images, c
151 # the classes, and b is a Nx2 tensor, with for every n:
153 # b[n, 0] ~ Uniform(0, 10)
154 # b[n, 1] ~ b[n, 0] + Uniform(0, 0.5) + c[n]
def create_image_values_pairs(train = False):
    """Build pairs made of an MNIST image and a pair of real values.

    Only the samples whose class is listed in `used_MNIST_classes` are
    kept. For every kept sample n:

        b[n, 0] ~ Uniform(0, 10)
        b[n, 1] ~ b[n, 0] + Uniform(0, 0.5) + c[n]

    With `args.independent` set, c[n] is replaced in the second formula
    by a class drawn uniformly among the used classes, so that `b`
    carries no information about the image.

    Args:
        train: if True the pairs are built from the training split,
            otherwise from the test split.

    Returns:
        A triplet (a, b, c): `a` the kept images, `b` an N x 2 float
        tensor, `c` the 1d long tensor of the classes of `a`.
    """
    if train:
        images, labels = train_input, train_target
    else:
        images, labels = test_input, test_target

    # Per-sample boolean mask: True where the label is one of the used
    # classes. Comparing every label against every used class keeps the
    # mask the same length as the data set and works whatever the
    # largest used class is.
    used = used_MNIST_classes.to(labels.device).view(1, -1)
    keep = (labels.view(-1, 1) == used).any(dim = 1)
    used_indices = keep.nonzero().view(-1)

    a = images[used_indices].contiguous()
    c = labels[used_indices].contiguous()

    b = a.new_empty((a.size(0), 2))
    b[:, 0].uniform_(0.0, 10.0)
    b[:, 1].uniform_(0.0, 0.5)

    if args.independent:
        random_classes = used_MNIST_classes[torch.randint(len(used_MNIST_classes), c.size())]
        b[:, 1] += b[:, 0] + random_classes.float()
    else:
        b[:, 1] += b[:, 0] + c.float()

    return a, b, c
187 ######################################################################
# Builds the pairs of noisy 1d sequences: 10000 pairs of sequences of
# length 1024, sampled on a regular grid over [0, 1]. Every peak is
# zero before its position and then decays by a factor of two every
# 0.1 units.
# NOTE(review): several statements of this function are not visible in
# this view. `noise_level`, `b` and the returned value are used or
# expected below but never defined in the lines shown, and it cannot be
# told which statements are conditional. The `train` argument is not
# read by any visible line. Confirm against the complete file.
191 def create_sequences_pairs(train = False):
192 nb, length = 10000, 1024
# One integer in 1..nb_classes per pair; it scales the peak of `a`.
195 ha = torch.randint(args.nb_classes, (nb, ), device = device) + 1
# One integer in 0..nb_classes-1 per pair; it sets the delay between
# the two peaks of the second sequence.
197 hb = torch.randint(args.nb_classes, (nb, ), device = device)
# First sequence: a single peak at a position drawn in [0, 0.9], whose
# height is ha / (1 + nb_classes).
201 pos = torch.empty(nb, device = device).uniform_(0.0, 0.9)
202 a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
203 a = a - pos.view(nb, 1)
204 a = (a >= 0).float() * torch.exp(-a * math.log(2) / 0.1)
205 a = a * ha.float().view(-1, 1).expand_as(a) / (1 + args.nb_classes)
# Gaussian noise with the shape of `a` (presumably added to `a` by a
# statement that is not visible -- TODO confirm).
206 noise = a.new(a.size()).normal_(0, noise_level)
# Second sequence, first peak: height 0.25, position drawn in [0, 0.5].
209 pos = torch.empty(nb, device = device).uniform_(0.0, 0.5)
210 b1 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
211 b1 = b1 - pos.view(nb, 1)
212 b1 = (b1 >= 0).float() * torch.exp(-b1 * math.log(2) / 0.1) * 0.25
# Second peak: height 0.25, delayed by hb / (nb_classes + 1) * 0.5 with
# respect to the first one.
213 pos = pos + hb.float() / (args.nb_classes + 1) * 0.5
214 # pos += pos.new(hb.size()).uniform_(0.0, 0.01)
215 b2 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
216 b2 = b2 - pos.view(nb, 1)
217 b2 = (b2 >= 0).float() * torch.exp(-b2 * math.log(2) / 0.1) * 0.25
# Gaussian noise with the shape of `b` (`b` is presumably built from b1
# and b2 by a statement that is not visible -- TODO confirm).
220 noise = b.new(b.size()).normal_(0, noise_level)
225 ######################################################################
# Network applied to a pair of single-channel images. Each image goes
# through its own convolutional feature extractor, the two flattened
# feature vectors are concatenated, and `fully_connected` maps the
# result to the network output.
# NOTE(review): the constructor header, the closing parentheses of the
# nn.Sequential calls and the layers of `fully_connected` are not
# visible in this view.
227 class NetForImagePair(nn.Module):
# Feature extractor of the first image: two convolution / max-pooling
# / ReLU stages, from 1 to 16 and then 32 channels.
230 self.features_a = nn.Sequential(
231 nn.Conv2d(1, 16, kernel_size = 5),
232 nn.MaxPool2d(3), nn.ReLU(),
233 nn.Conv2d(16, 32, kernel_size = 5),
234 nn.MaxPool2d(2), nn.ReLU(),
# Feature extractor of the second image: same layers as features_a,
# with its own parameters.
237 self.features_b = nn.Sequential(
238 nn.Conv2d(1, 16, kernel_size = 5),
239 nn.MaxPool2d(3), nn.ReLU(),
240 nn.Conv2d(16, 32, kernel_size = 5),
241 nn.MaxPool2d(2), nn.ReLU(),
244 self.fully_connected = nn.Sequential(
# a, b: batches of images with the same number of rows. The features of
# each are flattened to one row per sample before being concatenated
# along dimension 1.
250 def forward(self, a, b):
251 a = self.features_a(a).view(a.size(0), -1)
252 b = self.features_b(b).view(b.size(0), -1)
253 x = torch.cat((a, b), 1)
254 return self.fully_connected(x)
256 ######################################################################
# Network applied to a pair made of a single-channel image and a
# vector of two values. The image goes through a convolutional feature
# extractor and the values through a small fully connected one; the two
# flattened feature vectors are concatenated and `fully_connected`
# maps the result to the network output.
# NOTE(review): the constructor header, the closing parentheses of the
# nn.Sequential calls and the layers of `fully_connected` are not
# visible in this view.
258 class NetForImageValuesPair(nn.Module):
# Image branch: two convolution / max-pooling / ReLU stages, from 1 to
# 16 and then 32 channels.
261 self.features_a = nn.Sequential(
262 nn.Conv2d(1, 16, kernel_size = 5),
263 nn.MaxPool2d(3), nn.ReLU(),
264 nn.Conv2d(16, 32, kernel_size = 5),
265 nn.MaxPool2d(2), nn.ReLU(),
# Values branch: maps the 2 input values to 128 features through three
# linear layers, each followed by a ReLU.
268 self.features_b = nn.Sequential(
269 nn.Linear(2, 32), nn.ReLU(),
270 nn.Linear(32, 32), nn.ReLU(),
271 nn.Linear(32, 128), nn.ReLU(),
274 self.fully_connected = nn.Sequential(
# a: batch of images, b: batch of value pairs with the same number of
# rows. Both feature tensors are flattened to one row per sample before
# being concatenated along dimension 1.
280 def forward(self, a, b):
281 a = self.features_a(a).view(a.size(0), -1)
282 b = self.features_b(b).view(b.size(0), -1)
283 x = torch.cat((a, b), 1)
284 return self.fully_connected(x)
286 ######################################################################
# Network applied to a pair of 1d sequences. Each sequence goes through
# its own stack of 1d convolutions and average poolings, is averaged
# over its remaining length, and the two resulting vectors are
# concatenated and mapped to a single output value per sample.
# NOTE(review): the constructor header and the definitions of
# `self.nc`, `self.nh`, `kernel_size` and `pooling_size` are not
# visible in this view; other layers (e.g. activations) may be missing
# between the visible ones.
288 class NetForSequencePair(nn.Module):
# Builds one feature extractor: four Conv1d + AvgPool1d stages, going
# from 1 channel to self.nc channels. Each call returns a new module,
# so the two branches do not share parameters.
290 def feature_model(self):
293 return nn.Sequential(
294 nn.Conv1d( 1, self.nc, kernel_size = kernel_size),
295 nn.AvgPool1d(pooling_size),
297 nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
298 nn.AvgPool1d(pooling_size),
300 nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
301 nn.AvgPool1d(pooling_size),
303 nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
304 nn.AvgPool1d(pooling_size),
314 self.features_a = self.feature_model()
315 self.features_b = self.feature_model()
# Takes the 2 * nc concatenated features and ends with one output unit.
317 self.fully_connected = nn.Sequential(
318 nn.Linear(2 * self.nc, self.nh),
320 nn.Linear(self.nh, 1)
# a, b: 2d tensors with one sequence per row. A channel dimension of
# size 1 is inserted for the convolutions, and the average pooling over
# the full remaining length leaves one value per channel.
323 def forward(self, a, b):
324 a = a.view(a.size(0), 1, a.size(1))
325 a = self.features_a(a)
326 a = F.avg_pool1d(a, a.size(2))
328 b = b.view(b.size(0), 1, b.size(1))
329 b = self.features_b(b)
330 b = F.avg_pool1d(b, b.size(2))
332 x = torch.cat((a.view(a.size(0), -1), b.view(b.size(0), -1)), 1)
333 return self.fully_connected(x)
335 ######################################################################
# Selects the pair generator `create_pairs` and the matching network
# `model` according to the --data option.
337 if args.data == 'image_pair':
338 create_pairs = create_image_pairs
339 model = NetForImagePair()
341 elif args.data == 'image_values_pair':
342 create_pairs = create_image_values_pairs
343 model = NetForImageValuesPair()
345 elif args.data == 'sequence_pair':
346 create_pairs = create_sequences_pairs
347 model = NetForSequencePair()
349 ######################
# Writes generated sequence pairs to text files named train_XX.dat, one
# line per time step holding the value of a and the value of b.
# NOTE(review): the loop that defines `k` and the statement closing
# `file` are not visible in this view -- confirm that the file is
# closed (or use a `with` block).
351 a, b, c = create_pairs()
353 file = open(f'train_{k:02d}.dat', 'w')
354 for i in range(a.size(1)):
355 file.write(f'{a[k, i]:f} {b[k,i]:f}\n')
357 ######################
# Presumably the final `else:` branch of the chain above (its header is
# not visible): any other --data value is rejected.
360 raise Exception('Unknown data ' + args.data)
362 ######################################################################
# Training. The pairs are generated once from the training split; for
# every epoch the second components are shuffled to obtain un-paired
# samples, and the model is evaluated batch by batch on both the paired
# and the shuffled samples.
# NOTE(review): the statements that initialize and accumulate `acc_mi`
# and that run the backward pass and the optimizer step are not visible
# in this view. It also cannot be told from here whether the optimizer
# is created once or once per epoch.
365 print(f'nb_parameters {sum(x.numel() for x in model.parameters())}')
369 input_a, input_b, classes = create_pairs(train = True)
371 for e in range(args.nb_epochs):
373 optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
# A random permutation of the rows of input_b: batch_br[i] is no longer
# the component that was generated together with batch_a[i].
375 input_br = input_b[torch.randperm(input_b.size(0))]
379 for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
380 input_b.split(args.batch_size),
381 input_br.split(args.batch_size)):
# Batch estimate: mean of the model output on the paired samples minus
# the log of the mean of the exponential of its output on the shuffled
# samples.
382 mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
385 optimizer.zero_grad()
# NOTE(review): split() also yields a final, smaller batch when the
# number of samples is not a multiple of batch_size, while this divisor
# counts full batches only -- confirm against how acc_mi is accumulated.
389 acc_mi /= (input_a.size(0) // args.batch_size)
# Prints the epoch number, the estimate and the entropy of the classes,
# the last two divided by log(2), i.e. expressed in bits.
391 print(f'{e+1} {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}')
395 ######################################################################
# Evaluation on pairs generated from the test split, with the same
# batch estimate as during training but no optimizer call in the
# visible lines.
# NOTE(review): the statements that initialize and accumulate `acc_mi`
# are not visible in this view.
398 input_a, input_b, classes = create_pairs(train = False)
# Shuffled second components, used as un-paired samples.
400 input_br = input_b[torch.randperm(input_b.size(0))]
404 for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
405 input_b.split(args.batch_size),
406 input_br.split(args.batch_size)):
407 mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
# NOTE(review): same divisor as in training -- it counts full batches
# only, while split() also yields a final partial batch.
410 acc_mi /= (input_a.size(0) // args.batch_size)
# Prints the estimate and the entropy of the classes, both in bits.
412 print(f'test {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}')
414 ######################################################################