3 #########################################################################
4 # This program is free software: you can redistribute it and/or modify #
5 # it under the terms of the version 3 of the GNU General Public License #
6 # as published by the Free Software Foundation. #
8 # This program is distributed in the hope that it will be useful, but #
9 # WITHOUT ANY WARRANTY; without even the implied warranty of #
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU #
11 # General Public License for more details. #
13 # You should have received a copy of the GNU General Public License #
14 # along with this program. If not, see <http://www.gnu.org/licenses/>. #
16 # Written by and Copyright (C) Francois Fleuret #
17 # Contact <francois.fleuret@idiap.ch> for comments & bug reports #
18 #########################################################################
20 import argparse, math, sys
21 from copy import deepcopy
23 import torch, torchvision
26 import torch.nn.functional as F
28 ######################################################################
# Select the compute device once, at import time; the rest of the
# script reads the module-level `device`.
#
# The CPU device is only a fallback: assigning it unconditionally after
# the CUDA branch would discard the GPU even when one is available.
cuda_available = torch.cuda.is_available()

if cuda_available:
    # Let cuDNN benchmark its convolution algorithms.
    torch.backends.cudnn.benchmark = True

device = torch.device('cuda' if cuda_available else 'cpu')
36 ######################################################################
38 parser = argparse.ArgumentParser(
39 description = '''An implementation of a Mutual Information estimator with a deep model
41 Three different toy data-sets are implemented:
43 (1) Two MNIST images of same class. The "true" MI is the log of the
44 number of used MNIST classes.
46 (2) One MNIST image and a pair of real numbers whose difference is
47 the class of the image. The "true" MI is the log of the number of
50 (3) Two 1d sequences, the first with a single peak, the second with
51 two peaks, and the height of the peak in the first is the
52 difference of timing of the peaks in the second. The "true" MI is
53 the log of the number of possible peak heights.''',
55 formatter_class = argparse.ArgumentDefaultsHelpFormatter
58 parser.add_argument('--data',
59 type = str, default = 'image_pair',
60 help = 'What data: image_pair, image_values_pair, sequence_pair')
62 parser.add_argument('--seed',
63 type = int, default = 0,
64 help = 'Random seed (default 0, < 0 is no seeding)')
66 parser.add_argument('--mnist_classes',
67 type = str, default = '0, 1, 3, 5, 6, 7, 8, 9',
68 help = 'What MNIST classes to use')
70 parser.add_argument('--nb_classes',
71 type = int, default = 2,
72 help = 'How many classes for sequences')
74 parser.add_argument('--nb_epochs',
75 type = int, default = 50,
76 help = 'How many epochs')
78 parser.add_argument('--batch_size',
79 type = int, default = 100,
82 parser.add_argument('--learning_rate',
83 type = float, default = 1e-3,
86 parser.add_argument('--independent', action = 'store_true',
87 help = 'Should the pair components be independent')
90 ######################################################################
# Parse the command line once; later stages read their settings
# from `args`.
args = parser.parse_args()

# The --seed help text states that a negative value means "no seeding",
# so the generator is only seeded for non-negative values.
if args.seed >= 0:
    torch.manual_seed(args.seed)

# NOTE(review): this used to be eval('[' + args.mnist_classes + ']').
# eval() executes arbitrary expressions taken from the command line, so
# the comma-separated list is now parsed explicitly as integers. Plain
# lists such as the default '0, 1, 3, 5, 6, 7, 8, 9' give the same
# tensor as before; anything that is not an integer raises ValueError.
mnist_class_ids = [
    int(token) for token in args.mnist_classes.split(',') if token.strip()
]

if not mnist_class_ids:
    raise ValueError('--mnist_classes must list at least one class')

used_MNIST_classes = torch.tensor(mnist_class_ids, device = device)
99 ######################################################################
103 for k in range(target.max() + 1):
104 n = (target == k).sum().item()
105 if n > 0: probas.append(n)
106 probas = torch.tensor(probas).float()
107 probas /= probas.sum()
108 return - (probas * probas.log()).sum().item()
110 ######################################################################
112 train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
113 train_input = train_set.train_data.view(-1, 1, 28, 28).to(device).float()
114 train_target = train_set.train_labels.to(device)
116 test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True)
117 test_input = test_set.test_data.view(-1, 1, 28, 28).to(device).float()
118 test_target = test_set.test_labels.to(device)
120 mu, std = train_input.mean(), train_input.std()
121 train_input.sub_(mu).div_(std)
122 test_input.sub_(mu).div_(std)
124 ######################################################################
126 # Returns a triplet of tensors (a, b, c), where a and b each contain
127 # half of the samples, with a[i] and b[i] of the same class for any i,
128 # and c is a 1d long tensor of the real classes
130 def create_image_pairs(train = False):
131 ua, ub, uc = [], [], []
134 input, target = train_input, train_target
136 input, target = test_input, test_target
138 for i in used_MNIST_classes:
139 used_indices = torch.arange(input.size(0), device = target.device)\
140 .masked_select(target == i.item())
141 x = input[used_indices]
142 x = x[torch.randperm(x.size(0))]
144 ua.append(x.narrow(0, 0, hs))
145 ub.append(x.narrow(0, hs, hs))
146 uc.append(target[used_indices])
151 perm = torch.randperm(a.size(0))
152 a = a[perm].contiguous()
155 perm = torch.randperm(a.size(0))
156 b = b[perm].contiguous()
160 ######################################################################
162 # Returns a triplet a, b, c where a contains the standard MNIST images,
163 # c the classes, and b is a Nx2 tensor where, for every n:
165 # b[n, 0] ~ Uniform(0, 10)
166 # b[n, 1] ~ b[n, 0] + Uniform(0, 0.5) + c[n]
168 def create_image_values_pairs(train = False):
172 input, target = train_input, train_target
174 input, target = test_input, test_target
176 m = torch.zeros(used_MNIST_classes.max() + 1, dtype = torch.uint8, device = target.device)
177 m[used_MNIST_classes] = 1
179 used_indices = torch.arange(input.size(0), device = target.device).masked_select(m)
181 input = input[used_indices].contiguous()
182 target = target[used_indices].contiguous()
187 b = a.new(a.size(0), 2)
188 b[:, 0].uniform_(0.0, 10.0)
189 b[:, 1].uniform_(0.0, 0.5)
192 b[:, 1] += b[:, 0] + \
193 used_MNIST_classes[torch.randint(len(used_MNIST_classes), target.size())]
195 b[:, 1] += b[:, 0] + target.float()
199 ######################################################################
201 def create_sequences_pairs(train = False):
202 nb, length = 10000, 1024
205 ha = torch.randint(args.nb_classes, (nb, ), device = device) + 1
207 hb = torch.randint(args.nb_classes, (nb, ), device = device)
211 pos = torch.empty(nb, device = device).uniform_(0.0, 0.9)
212 a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
213 a = a - pos.view(nb, 1)
214 a = (a >= 0).float() * torch.exp(-a * math.log(2) / 0.1)
215 a = a * ha.float().view(-1, 1).expand_as(a) / (1 + args.nb_classes)
216 noise = a.new(a.size()).normal_(0, noise_level)
219 pos = torch.empty(nb, device = device).uniform_(0.0, 0.5)
220 b1 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
221 b1 = b1 - pos.view(nb, 1)
222 b1 = (b1 >= 0).float() * torch.exp(-b1 * math.log(2) / 0.1) * 0.25
223 pos = pos + hb.float() / (args.nb_classes + 1) * 0.5
224 # pos += pos.new(hb.size()).uniform_(0.0, 0.01)
225 b2 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
226 b2 = b2 - pos.view(nb, 1)
227 b2 = (b2 >= 0).float() * torch.exp(-b2 * math.log(2) / 0.1) * 0.25
230 noise = b.new(b.size()).normal_(0, noise_level)
233 # a = (a - a.mean()) / a.std()
234 # b = (b - b.mean()) / b.std()
238 ######################################################################
240 class NetForImagePair(nn.Module):
242 super(NetForImagePair, self).__init__()
243 self.features_a = nn.Sequential(
244 nn.Conv2d(1, 16, kernel_size = 5),
245 nn.MaxPool2d(3), nn.ReLU(),
246 nn.Conv2d(16, 32, kernel_size = 5),
247 nn.MaxPool2d(2), nn.ReLU(),
250 self.features_b = nn.Sequential(
251 nn.Conv2d(1, 16, kernel_size = 5),
252 nn.MaxPool2d(3), nn.ReLU(),
253 nn.Conv2d(16, 32, kernel_size = 5),
254 nn.MaxPool2d(2), nn.ReLU(),
257 self.fully_connected = nn.Sequential(
def forward(self, a, b):
    """Score a pair of inputs with the two-branch network.

    Each input goes through its own feature extractor
    (`features_a` for `a`, `features_b` for `b`). Both results are
    flattened to one row per sample, concatenated along dimension 1,
    and passed to `fully_connected`, whose output is returned.
    """
    rows_a, rows_b = a.size(0), b.size(0)

    # Flatten everything after the first dimension.
    code_a = self.features_a(a).view(rows_a, -1)
    code_b = self.features_b(b).view(rows_b, -1)

    joint = torch.cat((code_a, code_b), 1)
    return self.fully_connected(joint)
269 ######################################################################
271 class NetForImageValuesPair(nn.Module):
273 super(NetForImageValuesPair, self).__init__()
274 self.features_a = nn.Sequential(
275 nn.Conv2d(1, 16, kernel_size = 5),
276 nn.MaxPool2d(3), nn.ReLU(),
277 nn.Conv2d(16, 32, kernel_size = 5),
278 nn.MaxPool2d(2), nn.ReLU(),
281 self.features_b = nn.Sequential(
282 nn.Linear(2, 32), nn.ReLU(),
283 nn.Linear(32, 32), nn.ReLU(),
284 nn.Linear(32, 128), nn.ReLU(),
287 self.fully_connected = nn.Sequential(
def forward(self, a, b):
    """Score an (a, b) pair with the two-branch network.

    `a` is processed by `features_a` and `b` by `features_b`. Both
    outputs are reshaped to one row per sample, joined along
    dimension 1, and the result of `fully_connected` on the joined
    tensor is returned.
    """
    branch_a = self.features_a(a)
    branch_a = branch_a.view(a.size(0), -1)

    branch_b = self.features_b(b)
    branch_b = branch_b.view(b.size(0), -1)

    return self.fully_connected(torch.cat((branch_a, branch_b), 1))
299 ######################################################################
301 class NetForSequencePair(nn.Module):
303 def feature_model(self):
306 return nn.Sequential(
307 nn.Conv1d( 1, self.nc, kernel_size = kernel_size),
308 nn.AvgPool1d(pooling_size),
310 nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
311 nn.AvgPool1d(pooling_size),
313 nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
314 nn.AvgPool1d(pooling_size),
316 nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
317 nn.AvgPool1d(pooling_size),
322 super(NetForSequencePair, self).__init__()
327 self.features_a = self.feature_model()
328 self.features_b = self.feature_model()
330 self.fully_connected = nn.Sequential(
331 nn.Linear(2 * self.nc, self.nh),
333 nn.Linear(self.nh, 1)
def forward(self, a, b):
    """Score a pair of 2d inputs (one row per sample).

    Each input gets a singleton dimension inserted at position 1, is
    run through its own feature extractor, and is average-pooled
    over the whole of dimension 2. The two pooled results are
    flattened per sample, concatenated along dimension 1, and fed to
    `fully_connected`.
    """
    def pooled_features(batch, extractor):
        # (N, L) -> (N, 1, L), then features, then one average per channel.
        signal = batch.view(batch.size(0), 1, batch.size(1))
        maps = extractor(signal)
        return F.avg_pool1d(maps, maps.size(2))

    pooled_a = pooled_features(a, self.features_a)
    pooled_b = pooled_features(b, self.features_b)

    flat_a = pooled_a.view(pooled_a.size(0), -1)
    flat_b = pooled_b.view(pooled_b.size(0), -1)

    return self.fully_connected(torch.cat((flat_a, flat_b), 1))
348 ######################################################################
350 if args.data == 'image_pair':
351 create_pairs = create_image_pairs
352 model = NetForImagePair()
354 elif args.data == 'image_values_pair':
355 create_pairs = create_image_values_pairs
356 model = NetForImageValuesPair()
358 elif args.data == 'sequence_pair':
359 create_pairs = create_sequences_pairs
360 model = NetForSequencePair()
363 a, b, c = create_pairs()
365 file = open(f'train_{k:02d}.dat', 'w')
366 for i in range(a.size(1)):
367 file.write(f'{a[k, i]:f} {b[k,i]:f}\n')
371 raise Exception('Unknown data ' + args.data)
373 ######################################################################
376 print(f'nb_parameters {sum(x.numel() for x in model.parameters())}')
380 input_a, input_b, classes = create_pairs(train = True)
382 for e in range(args.nb_epochs):
384 optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
386 input_br = input_b[torch.randperm(input_b.size(0))]
390 for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
391 input_b.split(args.batch_size),
392 input_br.split(args.batch_size)):
393 mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
396 optimizer.zero_grad()
400 acc_mi /= (input_a.size(0) // args.batch_size)
402 print(f'{e+1} {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}')
406 ######################################################################
409 input_a, input_b, classes = create_pairs(train = False)
411 input_br = input_b[torch.randperm(input_b.size(0))]
415 for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
416 input_b.split(args.batch_size),
417 input_br.split(args.batch_size)):
418 mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
421 acc_mi /= (input_a.size(0) // args.batch_size)
423 print(f'test {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}')
425 ######################################################################