Update.
[pytorch.git] / mine_mnist.py
1 #!/usr/bin/env python
2
3 #########################################################################
4 # This program is free software: you can redistribute it and/or modify  #
5 # it under the terms of the version 3 of the GNU General Public License #
6 # as published by the Free Software Foundation.                         #
7 #                                                                       #
8 # This program is distributed in the hope that it will be useful, but   #
9 # WITHOUT ANY WARRANTY; without even the implied warranty of            #
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      #
11 # General Public License for more details.                              #
12 #                                                                       #
13 # You should have received a copy of the GNU General Public License     #
14 # along with this program. If not, see <http://www.gnu.org/licenses/>.  #
15 #                                                                       #
16 # Written by and Copyright (C) Francois Fleuret                         #
17 # Contact <francois.fleuret@idiap.ch> for comments & bug reports        #
18 #########################################################################
19
20 import argparse, math, sys
21 from copy import deepcopy
22
23 import torch, torchvision
24
25 from torch import nn
26 import torch.nn.functional as F
27
28 ######################################################################
29
30 if torch.cuda.is_available():
31     torch.backends.cudnn.benchmark = True
32     device = torch.device('cuda')
33 else:
34     device = torch.device('cpu')
35
36 ######################################################################
37
38 parser = argparse.ArgumentParser(
39     description = '''An implementation of a Mutual Information estimator with a deep model
40
41 Three different toy data-sets are implemented:
42
43  (1) Two MNIST images of same class. The "true" MI is the log of the
44      number of used MNIST classes.
45
46  (2) One MNIST image and a pair of real numbers whose difference is
47      the class of the image. The "true" MI is the log of the number of
48      used MNIST classes.
49
50  (3) Two 1d sequences, the first with a single peak, the second with
51      two peaks, and the height of the peak in the first is the
52      difference of timing of the peaks in the second. The "true" MI is
53      the log of the number of possible peak heights.''',
54
55     formatter_class = argparse.ArgumentDefaultsHelpFormatter
56 )
57
58 parser.add_argument('--data',
59                     type = str, default = 'image_pair',
60                     help = 'What data: image_pair, image_values_pair, sequence_pair')
61
62 parser.add_argument('--seed',
63                     type = int, default = 0,
64                     help = 'Random seed (default 0, < 0 is no seeding)')
65
66 parser.add_argument('--mnist_classes',
67                     type = str, default = '0, 1, 3, 5, 6, 7, 8, 9',
68                     help = 'What MNIST classes to use')
69
70 parser.add_argument('--nb_classes',
71                     type = int, default = 2,
72                     help = 'How many classes for sequences')
73
74 parser.add_argument('--nb_epochs',
75                     type = int, default = 50,
76                     help = 'How many epochs')
77
78 parser.add_argument('--batch_size',
79                     type = int, default = 100,
80                     help = 'Batch size')
81
82 parser.add_argument('--learning_rate',
83                     type = float, default = 1e-3,
84                     help = 'Batch size')
85
86 parser.add_argument('--independent', action = 'store_true',
87                     help = 'Should the pair components be independent')
88
89
90 ######################################################################
91
92 args = parser.parse_args()
93
94 if args.seed >= 0:
95     torch.manual_seed(args.seed)
96
97 used_MNIST_classes = torch.tensor(eval('[' + args.mnist_classes + ']'), device = device)
98
99 ######################################################################
100
101 def entropy(target):
102     probas = []
103     for k in range(target.max() + 1):
104         n = (target == k).sum().item()
105         if n > 0: probas.append(n)
106     probas = torch.tensor(probas).float()
107     probas /= probas.sum()
108     return - (probas * probas.log()).sum().item()
109
110 ######################################################################
111
112 train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
113 train_input  = train_set.train_data.view(-1, 1, 28, 28).to(device).float()
114 train_target = train_set.train_labels.to(device)
115
116 test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True)
117 test_input = test_set.test_data.view(-1, 1, 28, 28).to(device).float()
118 test_target = test_set.test_labels.to(device)
119
120 mu, std = train_input.mean(), train_input.std()
121 train_input.sub_(mu).div_(std)
122 test_input.sub_(mu).div_(std)
123
124 ######################################################################
125
126 # Returns a triplet of tensors (a, b, c), where a and b contain each
127 # half of the samples, with a[i] and b[i] of same class for any i, and
128 # c is a 1d long tensor real classes
129
130 def create_image_pairs(train = False):
131     ua, ub, uc = [], [], []
132
133     if train:
134         input, target = train_input, train_target
135     else:
136         input, target = test_input, test_target
137
138     for i in used_MNIST_classes:
139         used_indices = torch.arange(input.size(0), device = target.device)\
140                             .masked_select(target == i.item())
141         x = input[used_indices]
142         x = x[torch.randperm(x.size(0))]
143         hs = x.size(0)//2
144         ua.append(x.narrow(0, 0, hs))
145         ub.append(x.narrow(0, hs, hs))
146         uc.append(target[used_indices])
147
148     a = torch.cat(ua, 0)
149     b = torch.cat(ub, 0)
150     c = torch.cat(uc, 0)
151     perm = torch.randperm(a.size(0))
152     a = a[perm].contiguous()
153
154     if args.independent:
155         perm = torch.randperm(a.size(0))
156     b = b[perm].contiguous()
157
158     return a, b, c
159
160 ######################################################################
161
162 # Returns a triplet a, b, c where a are the standard MNIST images, c
163 # the classes, and b is a Nx2 tensor, with for every n:
164 #
165 #   b[n, 0] ~ Uniform(0, 10)
166 #   b[n, 1] ~ b[n, 0] + Uniform(0, 0.5) + c[n]
167
168 def create_image_values_pairs(train = False):
169     ua, ub = [], []
170
171     if train:
172         input, target = train_input, train_target
173     else:
174         input, target = test_input, test_target
175
176     m = torch.zeros(used_MNIST_classes.max() + 1, dtype = torch.uint8, device = target.device)
177     m[used_MNIST_classes] = 1
178     m = m[target]
179     used_indices = torch.arange(input.size(0), device = target.device).masked_select(m)
180
181     input = input[used_indices].contiguous()
182     target = target[used_indices].contiguous()
183
184     a = input
185     c = target
186
187     b = a.new(a.size(0), 2)
188     b[:, 0].uniform_(0.0, 10.0)
189     b[:, 1].uniform_(0.0, 0.5)
190
191     if args.independent:
192         b[:, 1] += b[:, 0] + \
193                    used_MNIST_classes[torch.randint(len(used_MNIST_classes), target.size())]
194     else:
195         b[:, 1] += b[:, 0] + target.float()
196
197     return a, b, c
198
199 ######################################################################
200
201 def create_sequences_pairs(train = False):
202     nb, length = 10000, 1024
203     noise_level = 2e-2
204
205     ha = torch.randint(args.nb_classes, (nb, ), device = device) + 1
206     if args.independent:
207         hb = torch.randint(args.nb_classes, (nb, ), device = device)
208     else:
209         hb = ha
210
211     pos = torch.empty(nb, device = device).uniform_(0.0, 0.9)
212     a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
213     a = a - pos.view(nb, 1)
214     a = (a >= 0).float() * torch.exp(-a * math.log(2) / 0.1)
215     a = a * ha.float().view(-1, 1).expand_as(a) / (1 + args.nb_classes)
216     noise = a.new(a.size()).normal_(0, noise_level)
217     a = a + noise
218
219     pos = torch.empty(nb, device = device).uniform_(0.0, 0.5)
220     b1 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
221     b1 = b1 - pos.view(nb, 1)
222     b1 = (b1 >= 0).float() * torch.exp(-b1 * math.log(2) / 0.1) * 0.25
223     pos = pos + hb.float() / (args.nb_classes + 1) * 0.5
224     # pos += pos.new(hb.size()).uniform_(0.0, 0.01)
225     b2 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
226     b2 = b2 - pos.view(nb, 1)
227     b2 = (b2 >= 0).float() * torch.exp(-b2 * math.log(2) / 0.1) * 0.25
228
229     b = b1 + b2
230     noise = b.new(b.size()).normal_(0, noise_level)
231     b = b + noise
232
233     # a = (a - a.mean()) / a.std()
234     # b = (b - b.mean()) / b.std()
235
236     return a, b, ha
237
238 ######################################################################
239
240 class NetForImagePair(nn.Module):
241     def __init__(self):
242         super(NetForImagePair, self).__init__()
243         self.features_a = nn.Sequential(
244             nn.Conv2d(1, 16, kernel_size = 5),
245             nn.MaxPool2d(3), nn.ReLU(),
246             nn.Conv2d(16, 32, kernel_size = 5),
247             nn.MaxPool2d(2), nn.ReLU(),
248         )
249
250         self.features_b = nn.Sequential(
251             nn.Conv2d(1, 16, kernel_size = 5),
252             nn.MaxPool2d(3), nn.ReLU(),
253             nn.Conv2d(16, 32, kernel_size = 5),
254             nn.MaxPool2d(2), nn.ReLU(),
255         )
256
257         self.fully_connected = nn.Sequential(
258             nn.Linear(256, 200),
259             nn.ReLU(),
260             nn.Linear(200, 1)
261         )
262
263     def forward(self, a, b):
264         a = self.features_a(a).view(a.size(0), -1)
265         b = self.features_b(b).view(b.size(0), -1)
266         x = torch.cat((a, b), 1)
267         return self.fully_connected(x)
268
269 ######################################################################
270
271 class NetForImageValuesPair(nn.Module):
272     def __init__(self):
273         super(NetForImageValuesPair, self).__init__()
274         self.features_a = nn.Sequential(
275             nn.Conv2d(1, 16, kernel_size = 5),
276             nn.MaxPool2d(3), nn.ReLU(),
277             nn.Conv2d(16, 32, kernel_size = 5),
278             nn.MaxPool2d(2), nn.ReLU(),
279         )
280
281         self.features_b = nn.Sequential(
282             nn.Linear(2, 32), nn.ReLU(),
283             nn.Linear(32, 32), nn.ReLU(),
284             nn.Linear(32, 128), nn.ReLU(),
285         )
286
287         self.fully_connected = nn.Sequential(
288             nn.Linear(256, 200),
289             nn.ReLU(),
290             nn.Linear(200, 1)
291         )
292
293     def forward(self, a, b):
294         a = self.features_a(a).view(a.size(0), -1)
295         b = self.features_b(b).view(b.size(0), -1)
296         x = torch.cat((a, b), 1)
297         return self.fully_connected(x)
298
299 ######################################################################
300
301 class NetForSequencePair(nn.Module):
302
303     def feature_model(self):
304         kernel_size = 11
305         pooling_size = 4
306         return  nn.Sequential(
307             nn.Conv1d(      1, self.nc, kernel_size = kernel_size),
308             nn.AvgPool1d(pooling_size),
309             nn.LeakyReLU(),
310             nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
311             nn.AvgPool1d(pooling_size),
312             nn.LeakyReLU(),
313             nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
314             nn.AvgPool1d(pooling_size),
315             nn.LeakyReLU(),
316             nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
317             nn.AvgPool1d(pooling_size),
318             nn.LeakyReLU(),
319         )
320
321     def __init__(self):
322         super(NetForSequencePair, self).__init__()
323
324         self.nc = 32
325         self.nh = 256
326
327         self.features_a = self.feature_model()
328         self.features_b = self.feature_model()
329
330         self.fully_connected = nn.Sequential(
331             nn.Linear(2 * self.nc, self.nh),
332             nn.ReLU(),
333             nn.Linear(self.nh, 1)
334         )
335
336     def forward(self, a, b):
337         a = a.view(a.size(0), 1, a.size(1))
338         a = self.features_a(a)
339         a = F.avg_pool1d(a, a.size(2))
340
341         b = b.view(b.size(0), 1, b.size(1))
342         b = self.features_b(b)
343         b = F.avg_pool1d(b, b.size(2))
344
345         x = torch.cat((a.view(a.size(0), -1), b.view(b.size(0), -1)), 1)
346         return self.fully_connected(x)
347
348 ######################################################################
349
350 if args.data == 'image_pair':
351     create_pairs = create_image_pairs
352     model = NetForImagePair()
353
354 elif args.data == 'image_values_pair':
355     create_pairs = create_image_values_pairs
356     model = NetForImageValuesPair()
357
358 elif args.data == 'sequence_pair':
359     create_pairs = create_sequences_pairs
360     model = NetForSequencePair()
361
362     ## Save for figures
363     a, b, c = create_pairs()
364     for k in range(10):
365         file = open(f'train_{k:02d}.dat', 'w')
366         for i in range(a.size(1)):
367             file.write(f'{a[k, i]:f} {b[k,i]:f}\n')
368         file.close()
369
370 else:
371     raise Exception('Unknown data ' + args.data)
372
373 ######################################################################
374 # Train
375
376 print(f'nb_parameters {sum(x.numel() for x in model.parameters())}')
377
378 model.to(device)
379
380 input_a, input_b, classes = create_pairs(train = True)
381
382 for e in range(args.nb_epochs):
383
384     optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
385
386     input_br = input_b[torch.randperm(input_b.size(0))]
387
388     acc_mi = 0.0
389
390     for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
391                                           input_b.split(args.batch_size),
392                                           input_br.split(args.batch_size)):
393         mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
394         acc_mi += mi.item()
395         loss = - mi
396         optimizer.zero_grad()
397         loss.backward()
398         optimizer.step()
399
400     acc_mi /= (input_a.size(0) // args.batch_size)
401
402     print(f'{e+1} {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}')
403
404     sys.stdout.flush()
405
406 ######################################################################
407 # Test
408
409 input_a, input_b, classes = create_pairs(train = False)
410
411 input_br = input_b[torch.randperm(input_b.size(0))]
412
413 acc_mi = 0.0
414
415 for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
416                                       input_b.split(args.batch_size),
417                                       input_br.split(args.batch_size)):
418     mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
419     acc_mi += mi.item()
420
421 acc_mi /= (input_a.size(0) // args.batch_size)
422
423 print(f'test {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}')
424
425 ######################################################################