Update.
[pytorch.git] / mi_estimator.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import argparse, math, sys
9 from copy import deepcopy
10
11 import torch, torchvision
12
13 from torch import nn
14 import torch.nn.functional as F
15
16 ######################################################################
17
18 if torch.cuda.is_available():
19     torch.backends.cudnn.benchmark = True
20     device = torch.device('cuda')
21 else:
22     device = torch.device('cpu')
23
24 ######################################################################
25
26 parser = argparse.ArgumentParser(
27     description = '''An implementation of a Mutual Information estimator with a deep model
28
29     Three different toy data-sets are implemented, each consists of
30     pairs of samples, that may be from different spaces:
31
32     (1) Two MNIST images of same class. The "true" MI is the log of the
33     number of used MNIST classes.
34
35     (2) One MNIST image and a pair of real numbers whose difference is
36     the class of the image. The "true" MI is the log of the number of
37     used MNIST classes.
38
39     (3) Two 1d sequences, the first with a single peak, the second with
40     two peaks, and the height of the peak in the first is the
41     difference of timing of the peaks in the second. The "true" MI is
42     the log of the number of possible peak heights.''',
43
44     formatter_class = argparse.ArgumentDefaultsHelpFormatter
45 )
46
47 parser.add_argument('--data',
48                     type = str, default = 'image_pair',
49                     help = 'What data: image_pair, image_values_pair, sequence_pair')
50
51 parser.add_argument('--seed',
52                     type = int, default = 0,
53                     help = 'Random seed (default 0, < 0 is no seeding)')
54
55 parser.add_argument('--mnist_classes',
56                     type = str, default = '0, 1, 3, 5, 6, 7, 8, 9',
57                     help = 'What MNIST classes to use')
58
59 parser.add_argument('--nb_classes',
60                     type = int, default = 2,
61                     help = 'How many classes for sequences')
62
63 parser.add_argument('--nb_epochs',
64                     type = int, default = 50,
65                     help = 'How many epochs')
66
67 parser.add_argument('--batch_size',
68                     type = int, default = 100,
69                     help = 'Batch size')
70
71 parser.add_argument('--learning_rate',
72                     type = float, default = 1e-3,
73                     help = 'Batch size')
74
75 parser.add_argument('--independent', action = 'store_true',
76                     help = 'Should the pair components be independent')
77
78 ######################################################################
79
80 args = parser.parse_args()
81
82 if args.seed >= 0:
83     torch.manual_seed(args.seed)
84
85 used_MNIST_classes = torch.tensor(eval('[' + args.mnist_classes + ']'), device = device)
86
87 ######################################################################
88
89 def entropy(target):
90     probas = []
91     for k in range(target.max() + 1):
92         n = (target == k).sum().item()
93         if n > 0: probas.append(n)
94     probas = torch.tensor(probas).float()
95     probas /= probas.sum()
96     return - (probas * probas.log()).sum().item()
97
98 ######################################################################
99
100 train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
101 train_input  = train_set.train_data.view(-1, 1, 28, 28).to(device).float()
102 train_target = train_set.train_labels.to(device)
103
104 test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True)
105 test_input = test_set.test_data.view(-1, 1, 28, 28).to(device).float()
106 test_target = test_set.test_labels.to(device)
107
108 mu, std = train_input.mean(), train_input.std()
109 train_input.sub_(mu).div_(std)
110 test_input.sub_(mu).div_(std)
111
112 ######################################################################
113
114 # Returns a triplet of tensors (a, b, c), where a and b contain each
115 # half of the samples, with a[i] and b[i] of same class for any i, and
116 # c is a 1d long tensor real classes
117
118 def create_image_pairs(train = False):
119     ua, ub, uc = [], [], []
120
121     if train:
122         input, target = train_input, train_target
123     else:
124         input, target = test_input, test_target
125
126     for i in used_MNIST_classes:
127         used_indices = torch.arange(input.size(0), device = target.device)\
128                             .masked_select(target == i.item())
129         x = input[used_indices]
130         x = x[torch.randperm(x.size(0))]
131         hs = x.size(0)//2
132         ua.append(x.narrow(0, 0, hs))
133         ub.append(x.narrow(0, hs, hs))
134         uc.append(target[used_indices])
135
136     a = torch.cat(ua, 0)
137     b = torch.cat(ub, 0)
138     c = torch.cat(uc, 0)
139     perm = torch.randperm(a.size(0))
140     a = a[perm].contiguous()
141
142     if args.independent:
143         perm = torch.randperm(a.size(0))
144     b = b[perm].contiguous()
145
146     return a, b, c
147
148 ######################################################################
149
150 # Returns a triplet a, b, c where a are the standard MNIST images, c
151 # the classes, and b is a Nx2 tensor, with for every n:
152 #
153 #   b[n, 0] ~ Uniform(0, 10)
154 #   b[n, 1] ~ b[n, 0] + Uniform(0, 0.5) + c[n]
155
156 def create_image_values_pairs(train = False):
157     ua, ub = [], []
158
159     if train:
160         input, target = train_input, train_target
161     else:
162         input, target = test_input, test_target
163
164     m = torch.zeros(used_MNIST_classes.max() + 1, dtype = torch.uint8, device = target.device)
165     m[used_MNIST_classes] = 1
166     m = m[target]
167     used_indices = torch.arange(input.size(0), device = target.device).masked_select(m)
168
169     input = input[used_indices].contiguous()
170     target = target[used_indices].contiguous()
171
172     a = input
173     c = target
174
175     b = a.new(a.size(0), 2)
176     b[:, 0].uniform_(0.0, 10.0)
177     b[:, 1].uniform_(0.0, 0.5)
178
179     if args.independent:
180         b[:, 1] += b[:, 0] + \
181                    used_MNIST_classes[torch.randint(len(used_MNIST_classes), target.size())]
182     else:
183         b[:, 1] += b[:, 0] + target.float()
184
185     return a, b, c
186
187 ######################################################################
188
189 #
190
191 def create_sequences_pairs(train = False):
192     nb, length = 10000, 1024
193     noise_level = 2e-2
194
195     ha = torch.randint(args.nb_classes, (nb, ), device = device) + 1
196     if args.independent:
197         hb = torch.randint(args.nb_classes, (nb, ), device = device)
198     else:
199         hb = ha
200
201     pos = torch.empty(nb, device = device).uniform_(0.0, 0.9)
202     a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
203     a = a - pos.view(nb, 1)
204     a = (a >= 0).float() * torch.exp(-a * math.log(2) / 0.1)
205     a = a * ha.float().view(-1, 1).expand_as(a) / (1 + args.nb_classes)
206     noise = a.new(a.size()).normal_(0, noise_level)
207     a = a + noise
208
209     pos = torch.empty(nb, device = device).uniform_(0.0, 0.5)
210     b1 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
211     b1 = b1 - pos.view(nb, 1)
212     b1 = (b1 >= 0).float() * torch.exp(-b1 * math.log(2) / 0.1) * 0.25
213     pos = pos + hb.float() / (args.nb_classes + 1) * 0.5
214     # pos += pos.new(hb.size()).uniform_(0.0, 0.01)
215     b2 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
216     b2 = b2 - pos.view(nb, 1)
217     b2 = (b2 >= 0).float() * torch.exp(-b2 * math.log(2) / 0.1) * 0.25
218
219     b = b1 + b2
220     noise = b.new(b.size()).normal_(0, noise_level)
221     b = b + noise
222
223     return a, b, ha
224
225 ######################################################################
226
227 class NetForImagePair(nn.Module):
228     def __init__(self):
229         super().__init__()
230         self.features_a = nn.Sequential(
231             nn.Conv2d(1, 16, kernel_size = 5),
232             nn.MaxPool2d(3), nn.ReLU(),
233             nn.Conv2d(16, 32, kernel_size = 5),
234             nn.MaxPool2d(2), nn.ReLU(),
235         )
236
237         self.features_b = nn.Sequential(
238             nn.Conv2d(1, 16, kernel_size = 5),
239             nn.MaxPool2d(3), nn.ReLU(),
240             nn.Conv2d(16, 32, kernel_size = 5),
241             nn.MaxPool2d(2), nn.ReLU(),
242         )
243
244         self.fully_connected = nn.Sequential(
245             nn.Linear(256, 200),
246             nn.ReLU(),
247             nn.Linear(200, 1)
248         )
249
250     def forward(self, a, b):
251         a = self.features_a(a).view(a.size(0), -1)
252         b = self.features_b(b).view(b.size(0), -1)
253         x = torch.cat((a, b), 1)
254         return self.fully_connected(x)
255
256 ######################################################################
257
258 class NetForImageValuesPair(nn.Module):
259     def __init__(self):
260         super().__init__()
261         self.features_a = nn.Sequential(
262             nn.Conv2d(1, 16, kernel_size = 5),
263             nn.MaxPool2d(3), nn.ReLU(),
264             nn.Conv2d(16, 32, kernel_size = 5),
265             nn.MaxPool2d(2), nn.ReLU(),
266         )
267
268         self.features_b = nn.Sequential(
269             nn.Linear(2, 32), nn.ReLU(),
270             nn.Linear(32, 32), nn.ReLU(),
271             nn.Linear(32, 128), nn.ReLU(),
272         )
273
274         self.fully_connected = nn.Sequential(
275             nn.Linear(256, 200),
276             nn.ReLU(),
277             nn.Linear(200, 1)
278         )
279
280     def forward(self, a, b):
281         a = self.features_a(a).view(a.size(0), -1)
282         b = self.features_b(b).view(b.size(0), -1)
283         x = torch.cat((a, b), 1)
284         return self.fully_connected(x)
285
286 ######################################################################
287
288 class NetForSequencePair(nn.Module):
289
290     def feature_model(self):
291         kernel_size = 11
292         pooling_size = 4
293         return  nn.Sequential(
294             nn.Conv1d(      1, self.nc, kernel_size = kernel_size),
295             nn.AvgPool1d(pooling_size),
296             nn.LeakyReLU(),
297             nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
298             nn.AvgPool1d(pooling_size),
299             nn.LeakyReLU(),
300             nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
301             nn.AvgPool1d(pooling_size),
302             nn.LeakyReLU(),
303             nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
304             nn.AvgPool1d(pooling_size),
305             nn.LeakyReLU(),
306         )
307
308     def __init__(self):
309         super().__init__()
310
311         self.nc = 32
312         self.nh = 256
313
314         self.features_a = self.feature_model()
315         self.features_b = self.feature_model()
316
317         self.fully_connected = nn.Sequential(
318             nn.Linear(2 * self.nc, self.nh),
319             nn.ReLU(),
320             nn.Linear(self.nh, 1)
321         )
322
323     def forward(self, a, b):
324         a = a.view(a.size(0), 1, a.size(1))
325         a = self.features_a(a)
326         a = F.avg_pool1d(a, a.size(2))
327
328         b = b.view(b.size(0), 1, b.size(1))
329         b = self.features_b(b)
330         b = F.avg_pool1d(b, b.size(2))
331
332         x = torch.cat((a.view(a.size(0), -1), b.view(b.size(0), -1)), 1)
333         return self.fully_connected(x)
334
335 ######################################################################
336
337 if args.data == 'image_pair':
338     create_pairs = create_image_pairs
339     model = NetForImagePair()
340
341 elif args.data == 'image_values_pair':
342     create_pairs = create_image_values_pairs
343     model = NetForImageValuesPair()
344
345 elif args.data == 'sequence_pair':
346     create_pairs = create_sequences_pairs
347     model = NetForSequencePair()
348
349     ######################
350     ## Save for figures
351     a, b, c = create_pairs()
352     for k in range(10):
353         file = open(f'train_{k:02d}.dat', 'w')
354         for i in range(a.size(1)):
355             file.write(f'{a[k, i]:f} {b[k,i]:f}\n')
356         file.close()
357     ######################
358
359 else:
360     raise Exception('Unknown data ' + args.data)
361
362 ######################################################################
363 # Train
364
365 print(f'nb_parameters {sum(x.numel() for x in model.parameters())}')
366
367 model.to(device)
368
369 input_a, input_b, classes = create_pairs(train = True)
370
371 for e in range(args.nb_epochs):
372
373     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
374
375     input_br = input_b[torch.randperm(input_b.size(0))]
376
377     acc_mi = 0.0
378
379     for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
380                                           input_b.split(args.batch_size),
381                                           input_br.split(args.batch_size)):
382         mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
383         acc_mi += mi.item()
384         loss = - mi
385         optimizer.zero_grad()
386         loss.backward()
387         optimizer.step()
388
389     acc_mi /= (input_a.size(0) // args.batch_size)
390
391     print(f'{e+1} {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}')
392
393     sys.stdout.flush()
394
395 ######################################################################
396 # Test
397
398 input_a, input_b, classes = create_pairs(train = False)
399
400 input_br = input_b[torch.randperm(input_b.size(0))]
401
402 acc_mi = 0.0
403
404 for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
405                                       input_b.split(args.batch_size),
406                                       input_br.split(args.batch_size)):
407     mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
408     acc_mi += mi.item()
409
410 acc_mi /= (input_a.size(0) // args.batch_size)
411
412 print(f'test {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}')
413
414 ######################################################################