412c6242f423fde815bc1eeb8d853401323e3afc
[pytorch.git] / mine_mnist.py
1 #!/usr/bin/env python
2
3 import argparse
4
5 import math, sys, torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10 ######################################################################
11
12 parser = argparse.ArgumentParser(
13     description = 'An implementation of Mutual Information estimator with a deep model',
14     formatter_class = argparse.ArgumentDefaultsHelpFormatter
15 )
16
17 parser.add_argument('--data',
18                     type = str, default = 'image_pair',
19                     help = 'What data')
20
21 parser.add_argument('--seed',
22                     type = int, default = 0,
23                     help = 'Random seed (default 0, < 0 is no seeding)')
24
25 parser.add_argument('--mnist_classes',
26                     type = str, default = '0, 1, 3, 5, 6, 7, 8, 9',
27                     help = 'What MNIST classes to use')
28
29 ######################################################################
30
31 args = parser.parse_args()
32
33 if args.seed >= 0:
34     torch.manual_seed(args.seed)
35
36 used_MNIST_classes = torch.tensor(eval('[' + args.mnist_classes + ']'))
37
38 ######################################################################
39
40 train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
41 train_input  = train_set.train_data.view(-1, 1, 28, 28).float()
42 train_target = train_set.train_labels
43
44 test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True)
45 test_input = test_set.test_data.view(-1, 1, 28, 28).float()
46 test_target = test_set.test_labels
47
48 if torch.cuda.is_available():
49     used_MNIST_classes = used_MNIST_classes.cuda()
50     train_input, train_target = train_input.cuda(), train_target.cuda()
51     test_input, test_target = test_input.cuda(), test_target.cuda()
52
53 mu, std = train_input.mean(), train_input.std()
54 train_input.sub_(mu).div_(std)
55 test_input.sub_(mu).div_(std)
56
57 ######################################################################
58
59 # Returns a triplet of tensors (a, b, c), where a and b contain each
60 # half of the samples, with a[i] and b[i] of same class for any i, and
61 # c is a 1d long tensor with the count of pairs per class used.
62
63 def create_image_pairs(train = False):
64     ua, ub = [], []
65
66     if train:
67         input, target = train_input, train_target
68     else:
69         input, target = test_input, test_target
70
71     for i in used_MNIST_classes:
72         used_indices = torch.arange(input.size(0), device = target.device)\
73                             .masked_select(target == i.item())
74         x = input[used_indices]
75         x = x[torch.randperm(x.size(0))]
76         hs = x.size(0)//2
77         ua.append(x.narrow(0, 0, hs))
78         ub.append(x.narrow(0, hs, hs))
79
80     a = torch.cat(ua, 0)
81     b = torch.cat(ub, 0)
82     perm = torch.randperm(a.size(0))
83     a = a[perm].contiguous()
84     b = b[perm].contiguous()
85     c = torch.tensor([x.size(0) for x in ua])
86
87     return a, b, c
88
89 ######################################################################
90
91 def create_image_values_pairs(train = False):
92     ua, ub = [], []
93
94     if train:
95         input, target = train_input, train_target
96     else:
97         input, target = test_input, test_target
98
99     m = torch.zeros(used_MNIST_classes.max() + 1, dtype = torch.uint8, device = target.device)
100     m[used_MNIST_classes] = 1
101     m = m[target]
102     used_indices = torch.arange(input.size(0), device = target.device).masked_select(m)
103
104     input = input[used_indices].contiguous()
105     target = target[used_indices].contiguous()
106
107     a = input
108
109     b = a.new(a.size(0), 2)
110     b[:, 0].uniform_(10)
111     b[:, 1].uniform_(0.5)
112     b[:, 1] += b[:, 0] + target.float()
113
114     c = torch.tensor([(target == k).sum().item() for k in used_MNIST_classes])
115
116     return a, b, c
117
118 ######################################################################
119
120 class NetImagePair(nn.Module):
121     def __init__(self):
122         super(NetImagePair, self).__init__()
123         self.features_a = nn.Sequential(
124             nn.Conv2d(1, 16, kernel_size = 5),
125             nn.MaxPool2d(3), nn.ReLU(),
126             nn.Conv2d(16, 32, kernel_size = 5),
127             nn.MaxPool2d(2), nn.ReLU(),
128         )
129
130         self.features_b = nn.Sequential(
131             nn.Conv2d(1, 16, kernel_size = 5),
132             nn.MaxPool2d(3), nn.ReLU(),
133             nn.Conv2d(16, 32, kernel_size = 5),
134             nn.MaxPool2d(2), nn.ReLU(),
135         )
136
137         self.fully_connected = nn.Sequential(
138             nn.Linear(256, 200),
139             nn.ReLU(),
140             nn.Linear(200, 1)
141         )
142
143     def forward(self, a, b):
144         a = self.features_a(a).view(a.size(0), -1)
145         b = self.features_b(b).view(b.size(0), -1)
146         x = torch.cat((a, b), 1)
147         return self.fully_connected(x)
148
149 ######################################################################
150
151 class NetImageValuesPair(nn.Module):
152     def __init__(self):
153         super(NetImageValuesPair, self).__init__()
154         self.features_a = nn.Sequential(
155             nn.Conv2d(1, 16, kernel_size = 5),
156             nn.MaxPool2d(3), nn.ReLU(),
157             nn.Conv2d(16, 32, kernel_size = 5),
158             nn.MaxPool2d(2), nn.ReLU(),
159         )
160
161         self.features_b = nn.Sequential(
162             nn.Linear(2, 32), nn.ReLU(),
163             nn.Linear(32, 32), nn.ReLU(),
164             nn.Linear(32, 128), nn.ReLU(),
165         )
166
167         self.fully_connected = nn.Sequential(
168             nn.Linear(256, 200),
169             nn.ReLU(),
170             nn.Linear(200, 1)
171         )
172
173     def forward(self, a, b):
174         a = self.features_a(a).view(a.size(0), -1)
175         b = self.features_b(b).view(b.size(0), -1)
176         x = torch.cat((a, b), 1)
177         return self.fully_connected(x)
178
179 ######################################################################
180
181 if args.data == 'image_pair':
182     create_pairs = create_image_pairs
183     model = NetImagePair()
184 elif args.data == 'image_values_pair':
185     create_pairs = create_image_values_pairs
186     model = NetImageValuesPair()
187 else:
188     raise Exception('Unknown data ' + args.data)
189
190 ######################################################################
191
192 nb_epochs, batch_size = 50, 100
193
194 print('nb_parameters %d' % sum(x.numel() for x in model.parameters()))
195
196 optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
197
198 if torch.cuda.is_available():
199     model.cuda()
200
201 for e in range(nb_epochs):
202
203     input_a, input_b, count = create_pairs(train = True)
204
205     # The information bound is the entropy of the class distribution
206     class_proba = count.float()
207     class_proba /= class_proba.sum()
208     class_entropy = - (class_proba.log() * class_proba).sum().item()
209
210     input_br = input_b[torch.randperm(input_b.size(0))]
211
212     acc_mi = 0.0
213
214     for batch_a, batch_b, batch_br in zip(input_a.split(batch_size),
215                                           input_b.split(batch_size),
216                                           input_br.split(batch_size)):
217         mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
218         loss = - mi
219         acc_mi += mi.item()
220         optimizer.zero_grad()
221         loss.backward()
222         optimizer.step()
223
224     acc_mi /= (input_a.size(0) // batch_size)
225
226     print('%d %.04f %.04f' % (e, acc_mi / math.log(2), class_entropy / math.log(2)))
227
228     sys.stdout.flush()
229
230 ######################################################################
231
232 input_a, input_b, count = create_pairs(train = False)
233
234 for e in range(nb_epochs):
235     class_proba = count.float()
236     class_proba /= class_proba.sum()
237     class_entropy = - (class_proba.log() * class_proba).sum().item()
238
239     input_br = input_b[torch.randperm(input_b.size(0))]
240
241     acc_mi = 0.0
242
243     for batch_a, batch_b, batch_br in zip(input_a.split(batch_size),
244                                           input_b.split(batch_size),
245                                           input_br.split(batch_size)):
246         mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
247         acc_mi += mi.item()
248
249     acc_mi /= (input_a.size(0) // batch_size)
250
251 print('test %.04f %.04f'%(acc_mi / math.log(2), class_entropy / math.log(2)))
252
253 ######################################################################