Initial commit.
[pytorch.git] / mi_estimator.py
1 #!/usr/bin/env python
2
3 #########################################################################
4 # This program is free software: you can redistribute it and/or modify  #
5 # it under the terms of the version 3 of the GNU General Public License #
6 # as published by the Free Software Foundation.                         #
7 #                                                                       #
8 # This program is distributed in the hope that it will be useful, but   #
9 # WITHOUT ANY WARRANTY; without even the implied warranty of            #
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      #
11 # General Public License for more details.                              #
12 #                                                                       #
13 # You should have received a copy of the GNU General Public License     #
14 # along with this program. If not, see <http://www.gnu.org/licenses/>.  #
15 #                                                                       #
16 # Written by Francois Fleuret, (C) Idiap Research Institute             #
17 #                                                                       #
18 # Contact <francois.fleuret@idiap.ch> for comments & bug reports        #
19 #########################################################################
20
21 import argparse, math, sys
22 from copy import deepcopy
23
24 import torch, torchvision
25
26 from torch import nn
27 import torch.nn.functional as F
28
29 ######################################################################
30
31 if torch.cuda.is_available():
32     torch.backends.cudnn.benchmark = True
33     device = torch.device('cuda')
34 else:
35     device = torch.device('cpu')
36
37 ######################################################################
38
39 parser = argparse.ArgumentParser(
40     description = '''An implementation of a Mutual Information estimator with a deep model
41
42     Three different toy data-sets are implemented, each consists of
43     pairs of samples, that may be from different spaces:
44
45     (1) Two MNIST images of same class. The "true" MI is the log of the
46     number of used MNIST classes.
47
48     (2) One MNIST image and a pair of real numbers whose difference is
49     the class of the image. The "true" MI is the log of the number of
50     used MNIST classes.
51
52     (3) Two 1d sequences, the first with a single peak, the second with
53     two peaks, and the height of the peak in the first is the
54     difference of timing of the peaks in the second. The "true" MI is
55     the log of the number of possible peak heights.''',
56
57     formatter_class = argparse.ArgumentDefaultsHelpFormatter
58 )
59
60 parser.add_argument('--data',
61                     type = str, default = 'image_pair',
62                     help = 'What data: image_pair, image_values_pair, sequence_pair')
63
64 parser.add_argument('--seed',
65                     type = int, default = 0,
66                     help = 'Random seed (default 0, < 0 is no seeding)')
67
68 parser.add_argument('--mnist_classes',
69                     type = str, default = '0, 1, 3, 5, 6, 7, 8, 9',
70                     help = 'What MNIST classes to use')
71
72 parser.add_argument('--nb_classes',
73                     type = int, default = 2,
74                     help = 'How many classes for sequences')
75
76 parser.add_argument('--nb_epochs',
77                     type = int, default = 50,
78                     help = 'How many epochs')
79
80 parser.add_argument('--batch_size',
81                     type = int, default = 100,
82                     help = 'Batch size')
83
84 parser.add_argument('--learning_rate',
85                     type = float, default = 1e-3,
86                     help = 'Batch size')
87
88 parser.add_argument('--independent', action = 'store_true',
89                     help = 'Should the pair components be independent')
90
91 ######################################################################
92
93 args = parser.parse_args()
94
95 if args.seed >= 0:
96     torch.manual_seed(args.seed)
97
98 used_MNIST_classes = torch.tensor(eval('[' + args.mnist_classes + ']'), device = device)
99
100 ######################################################################
101
102 def entropy(target):
103     probas = []
104     for k in range(target.max() + 1):
105         n = (target == k).sum().item()
106         if n > 0: probas.append(n)
107     probas = torch.tensor(probas).float()
108     probas /= probas.sum()
109     return - (probas * probas.log()).sum().item()
110
111 ######################################################################
112
113 train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
114 train_input  = train_set.train_data.view(-1, 1, 28, 28).to(device).float()
115 train_target = train_set.train_labels.to(device)
116
117 test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True)
118 test_input = test_set.test_data.view(-1, 1, 28, 28).to(device).float()
119 test_target = test_set.test_labels.to(device)
120
121 mu, std = train_input.mean(), train_input.std()
122 train_input.sub_(mu).div_(std)
123 test_input.sub_(mu).div_(std)
124
125 ######################################################################
126
127 # Returns a triplet of tensors (a, b, c), where a and b contain each
128 # half of the samples, with a[i] and b[i] of same class for any i, and
129 # c is a 1d long tensor real classes
130
131 def create_image_pairs(train = False):
132     ua, ub, uc = [], [], []
133
134     if train:
135         input, target = train_input, train_target
136     else:
137         input, target = test_input, test_target
138
139     for i in used_MNIST_classes:
140         used_indices = torch.arange(input.size(0), device = target.device)\
141                             .masked_select(target == i.item())
142         x = input[used_indices]
143         x = x[torch.randperm(x.size(0))]
144         hs = x.size(0)//2
145         ua.append(x.narrow(0, 0, hs))
146         ub.append(x.narrow(0, hs, hs))
147         uc.append(target[used_indices])
148
149     a = torch.cat(ua, 0)
150     b = torch.cat(ub, 0)
151     c = torch.cat(uc, 0)
152     perm = torch.randperm(a.size(0))
153     a = a[perm].contiguous()
154
155     if args.independent:
156         perm = torch.randperm(a.size(0))
157     b = b[perm].contiguous()
158
159     return a, b, c
160
161 ######################################################################
162
163 # Returns a triplet a, b, c where a are the standard MNIST images, c
164 # the classes, and b is a Nx2 tensor, with for every n:
165 #
166 #   b[n, 0] ~ Uniform(0, 10)
167 #   b[n, 1] ~ b[n, 0] + Uniform(0, 0.5) + c[n]
168
169 def create_image_values_pairs(train = False):
170     ua, ub = [], []
171
172     if train:
173         input, target = train_input, train_target
174     else:
175         input, target = test_input, test_target
176
177     m = torch.zeros(used_MNIST_classes.max() + 1, dtype = torch.uint8, device = target.device)
178     m[used_MNIST_classes] = 1
179     m = m[target]
180     used_indices = torch.arange(input.size(0), device = target.device).masked_select(m)
181
182     input = input[used_indices].contiguous()
183     target = target[used_indices].contiguous()
184
185     a = input
186     c = target
187
188     b = a.new(a.size(0), 2)
189     b[:, 0].uniform_(0.0, 10.0)
190     b[:, 1].uniform_(0.0, 0.5)
191
192     if args.independent:
193         b[:, 1] += b[:, 0] + \
194                    used_MNIST_classes[torch.randint(len(used_MNIST_classes), target.size())]
195     else:
196         b[:, 1] += b[:, 0] + target.float()
197
198     return a, b, c
199
200 ######################################################################
201
202 #
203
204 def create_sequences_pairs(train = False):
205     nb, length = 10000, 1024
206     noise_level = 2e-2
207
208     ha = torch.randint(args.nb_classes, (nb, ), device = device) + 1
209     if args.independent:
210         hb = torch.randint(args.nb_classes, (nb, ), device = device)
211     else:
212         hb = ha
213
214     pos = torch.empty(nb, device = device).uniform_(0.0, 0.9)
215     a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
216     a = a - pos.view(nb, 1)
217     a = (a >= 0).float() * torch.exp(-a * math.log(2) / 0.1)
218     a = a * ha.float().view(-1, 1).expand_as(a) / (1 + args.nb_classes)
219     noise = a.new(a.size()).normal_(0, noise_level)
220     a = a + noise
221
222     pos = torch.empty(nb, device = device).uniform_(0.0, 0.5)
223     b1 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
224     b1 = b1 - pos.view(nb, 1)
225     b1 = (b1 >= 0).float() * torch.exp(-b1 * math.log(2) / 0.1) * 0.25
226     pos = pos + hb.float() / (args.nb_classes + 1) * 0.5
227     # pos += pos.new(hb.size()).uniform_(0.0, 0.01)
228     b2 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
229     b2 = b2 - pos.view(nb, 1)
230     b2 = (b2 >= 0).float() * torch.exp(-b2 * math.log(2) / 0.1) * 0.25
231
232     b = b1 + b2
233     noise = b.new(b.size()).normal_(0, noise_level)
234     b = b + noise
235
236     return a, b, ha
237
238 ######################################################################
239
240 class NetForImagePair(nn.Module):
241     def __init__(self):
242         super(NetForImagePair, self).__init__()
243         self.features_a = nn.Sequential(
244             nn.Conv2d(1, 16, kernel_size = 5),
245             nn.MaxPool2d(3), nn.ReLU(),
246             nn.Conv2d(16, 32, kernel_size = 5),
247             nn.MaxPool2d(2), nn.ReLU(),
248         )
249
250         self.features_b = nn.Sequential(
251             nn.Conv2d(1, 16, kernel_size = 5),
252             nn.MaxPool2d(3), nn.ReLU(),
253             nn.Conv2d(16, 32, kernel_size = 5),
254             nn.MaxPool2d(2), nn.ReLU(),
255         )
256
257         self.fully_connected = nn.Sequential(
258             nn.Linear(256, 200),
259             nn.ReLU(),
260             nn.Linear(200, 1)
261         )
262
263     def forward(self, a, b):
264         a = self.features_a(a).view(a.size(0), -1)
265         b = self.features_b(b).view(b.size(0), -1)
266         x = torch.cat((a, b), 1)
267         return self.fully_connected(x)
268
269 ######################################################################
270
271 class NetForImageValuesPair(nn.Module):
272     def __init__(self):
273         super(NetForImageValuesPair, self).__init__()
274         self.features_a = nn.Sequential(
275             nn.Conv2d(1, 16, kernel_size = 5),
276             nn.MaxPool2d(3), nn.ReLU(),
277             nn.Conv2d(16, 32, kernel_size = 5),
278             nn.MaxPool2d(2), nn.ReLU(),
279         )
280
281         self.features_b = nn.Sequential(
282             nn.Linear(2, 32), nn.ReLU(),
283             nn.Linear(32, 32), nn.ReLU(),
284             nn.Linear(32, 128), nn.ReLU(),
285         )
286
287         self.fully_connected = nn.Sequential(
288             nn.Linear(256, 200),
289             nn.ReLU(),
290             nn.Linear(200, 1)
291         )
292
293     def forward(self, a, b):
294         a = self.features_a(a).view(a.size(0), -1)
295         b = self.features_b(b).view(b.size(0), -1)
296         x = torch.cat((a, b), 1)
297         return self.fully_connected(x)
298
299 ######################################################################
300
301 class NetForSequencePair(nn.Module):
302
303     def feature_model(self):
304         kernel_size = 11
305         pooling_size = 4
306         return  nn.Sequential(
307             nn.Conv1d(      1, self.nc, kernel_size = kernel_size),
308             nn.AvgPool1d(pooling_size),
309             nn.LeakyReLU(),
310             nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
311             nn.AvgPool1d(pooling_size),
312             nn.LeakyReLU(),
313             nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
314             nn.AvgPool1d(pooling_size),
315             nn.LeakyReLU(),
316             nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
317             nn.AvgPool1d(pooling_size),
318             nn.LeakyReLU(),
319         )
320
321     def __init__(self):
322         super(NetForSequencePair, self).__init__()
323
324         self.nc = 32
325         self.nh = 256
326
327         self.features_a = self.feature_model()
328         self.features_b = self.feature_model()
329
330         self.fully_connected = nn.Sequential(
331             nn.Linear(2 * self.nc, self.nh),
332             nn.ReLU(),
333             nn.Linear(self.nh, 1)
334         )
335
336     def forward(self, a, b):
337         a = a.view(a.size(0), 1, a.size(1))
338         a = self.features_a(a)
339         a = F.avg_pool1d(a, a.size(2))
340
341         b = b.view(b.size(0), 1, b.size(1))
342         b = self.features_b(b)
343         b = F.avg_pool1d(b, b.size(2))
344
345         x = torch.cat((a.view(a.size(0), -1), b.view(b.size(0), -1)), 1)
346         return self.fully_connected(x)
347
348 ######################################################################
349
350 if args.data == 'image_pair':
351     create_pairs = create_image_pairs
352     model = NetForImagePair()
353
354 elif args.data == 'image_values_pair':
355     create_pairs = create_image_values_pairs
356     model = NetForImageValuesPair()
357
358 elif args.data == 'sequence_pair':
359     create_pairs = create_sequences_pairs
360     model = NetForSequencePair()
361
362     ######################
363     ## Save for figures
364     a, b, c = create_pairs()
365     for k in range(10):
366         file = open(f'train_{k:02d}.dat', 'w')
367         for i in range(a.size(1)):
368             file.write(f'{a[k, i]:f} {b[k,i]:f}\n')
369         file.close()
370     ######################
371
372 else:
373     raise Exception('Unknown data ' + args.data)
374
375 ######################################################################
376 # Train
377
378 print(f'nb_parameters {sum(x.numel() for x in model.parameters())}')
379
380 model.to(device)
381
382 input_a, input_b, classes = create_pairs(train = True)
383
384 for e in range(args.nb_epochs):
385
386     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
387
388     input_br = input_b[torch.randperm(input_b.size(0))]
389
390     acc_mi = 0.0
391
392     for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
393                                           input_b.split(args.batch_size),
394                                           input_br.split(args.batch_size)):
395         mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
396         acc_mi += mi.item()
397         loss = - mi
398         optimizer.zero_grad()
399         loss.backward()
400         optimizer.step()
401
402     acc_mi /= (input_a.size(0) // args.batch_size)
403
404     print(f'{e+1} {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}')
405
406     sys.stdout.flush()
407
408 ######################################################################
409 # Test
410
411 input_a, input_b, classes = create_pairs(train = False)
412
413 input_br = input_b[torch.randperm(input_b.size(0))]
414
415 acc_mi = 0.0
416
417 for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
418                                       input_b.split(args.batch_size),
419                                       input_br.split(args.batch_size)):
420     mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
421     acc_mi += mi.item()
422
423 acc_mi /= (input_a.size(0) // args.batch_size)
424
425 print(f'test {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}')
426
427 ######################################################################