X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mine_mnist.py;h=1d69640af452af658d5932d4723dccc38c10d056;hb=75267f198e8f6cf476cb73d2846653494d7164b6;hp=412c6242f423fde815bc1eeb8d853401323e3afc;hpb=e7b065a38122910f512b87ac9551b3ac535361a9;p=pytorch.git diff --git a/mine_mnist.py b/mine_mnist.py index 412c624..1d69640 100755 --- a/mine_mnist.py +++ b/mine_mnist.py @@ -1,22 +1,63 @@ #!/usr/bin/env python -import argparse - -import math, sys, torch, torchvision +######################################################################### +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the version 3 of the GNU General Public License # +# as published by the Free Software Foundation. # +# # +# This program is distributed in the hope that it will be useful, but # +# WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # +# General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see . # +# # +# Written by and Copyright (C) Francois Fleuret # +# Contact for comments & bug reports # +######################################################################### + +import argparse, math, sys +from copy import deepcopy + +import torch, torchvision from torch import nn -from torch.nn import functional as F +import torch.nn.functional as F + +###################################################################### + +if torch.cuda.is_available(): + torch.backends.cudnn.benchmark = True + device = torch.device('cuda') +else: + device = torch.device('cpu') ###################################################################### parser = argparse.ArgumentParser( - description = 'An implementation of Mutual Information estimator with a deep model', + description = '''An implementation of a Mutual Information estimator with a deep model + +Three different toy data-sets are implemented: + + (1) Two MNIST images of same class. The "true" MI is the log of the + number of used MNIST classes. + + (2) One MNIST image and a pair of real numbers whose difference is + the class of the image. The "true" MI is the log of the number of + used MNIST classes. + + (3) Two 1d sequences, the first with a single peak, the second with + two peaks, and the height of the peak in the first is the + difference of timing of the peaks in the second. The "true" MI is + the log of the number of possible peak heights.''', + formatter_class = argparse.ArgumentDefaultsHelpFormatter ) parser.add_argument('--data', type = str, default = 'image_pair', - help = 'What data') + help = 'What data: image_pair, image_values_pair, sequence_pair') parser.add_argument('--seed', type = int, default = 0, @@ -26,6 +67,26 @@ parser.add_argument('--mnist_classes', type = str, default = '0, 1, 3, 5, 6, 7, 8, 9', help = 'What MNIST classes to use') +parser.add_argument('--nb_classes', + type = int, default = 2, + help = 'How many classes for sequences') + +parser.add_argument('--nb_epochs', + type = int, default = 50, + help = 'How many epochs') + +parser.add_argument('--batch_size', + type = int, default = 100, + help = 'Batch size') + +parser.add_argument('--learning_rate', + type = float, default = 1e-3, + help = 'Batch size') + +parser.add_argument('--independent', action = 'store_true', + help = 'Should the pair components be independent') + + ###################################################################### args = parser.parse_args() @@ -33,22 +94,28 @@ args = parser.parse_args() if args.seed >= 0: torch.manual_seed(args.seed) -used_MNIST_classes = torch.tensor(eval('[' + args.mnist_classes + ']')) +used_MNIST_classes = torch.tensor(eval('[' + args.mnist_classes + ']'), device = device) + +###################################################################### + +def entropy(target): + probas = [] + for k in range(target.max() + 1): + n = (target == k).sum().item() + if n > 0: probas.append(n) + probas = torch.tensor(probas).float() + probas /= probas.sum() + return - (probas * probas.log()).sum().item() ###################################################################### train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True) -train_input = train_set.train_data.view(-1, 1, 28, 28).float() -train_target = train_set.train_labels +train_input = train_set.train_data.view(-1, 1, 28, 28).to(device).float() +train_target = train_set.train_labels.to(device) test_set = torchvision.datasets.MNIST('./data/mnist/', train = False, download = True) -test_input = test_set.test_data.view(-1, 1, 28, 28).float() -test_target = test_set.test_labels - -if torch.cuda.is_available(): - used_MNIST_classes = used_MNIST_classes.cuda() - train_input, train_target = train_input.cuda(), train_target.cuda() - test_input, test_target = test_input.cuda(), test_target.cuda() +test_input = test_set.test_data.view(-1, 1, 28, 28).to(device).float() +test_target = test_set.test_labels.to(device) mu, std = train_input.mean(), train_input.std() train_input.sub_(mu).div_(std) @@ -58,10 +125,10 @@ test_input.sub_(mu).div_(std) # Returns a triplet of tensors (a, b, c), where a and b contain each # half of the samples, with a[i] and b[i] of same class for any i, and -# c is a 1d long tensor with the count of pairs per class used. +# c is a 1d long tensor real classes def create_image_pairs(train = False): - ua, ub = [], [] + ua, ub, uc = [], [], [] if train: input, target = train_input, train_target @@ -76,18 +143,28 @@ def create_image_pairs(train = False): hs = x.size(0)//2 ua.append(x.narrow(0, 0, hs)) ub.append(x.narrow(0, hs, hs)) + uc.append(target[used_indices]) a = torch.cat(ua, 0) b = torch.cat(ub, 0) + c = torch.cat(uc, 0) perm = torch.randperm(a.size(0)) a = a[perm].contiguous() + + if args.independent: + perm = torch.randperm(a.size(0)) b = b[perm].contiguous() - c = torch.tensor([x.size(0) for x in ua]) return a, b, c ###################################################################### +# Returns a triplet a, b, c where a are the standard MNIST images, c +# the classes, and b is a Nx2 tensor, with for every n: +# +# b[n, 0] ~ Uniform(0, 10) +# b[n, 1] ~ b[n, 0] + Uniform(0, 0.5) + c[n] + def create_image_values_pairs(train = False): ua, ub = [], [] @@ -105,21 +182,64 @@ def create_image_values_pairs(train = False): target = target[used_indices].contiguous() a = input + c = target b = a.new(a.size(0), 2) - b[:, 0].uniform_(10) - b[:, 1].uniform_(0.5) - b[:, 1] += b[:, 0] + target.float() + b[:, 0].uniform_(0.0, 10.0) + b[:, 1].uniform_(0.0, 0.5) - c = torch.tensor([(target == k).sum().item() for k in used_MNIST_classes]) + if args.independent: + b[:, 1] += b[:, 0] + \ + used_MNIST_classes[torch.randint(len(used_MNIST_classes), target.size())] + else: + b[:, 1] += b[:, 0] + target.float() return a, b, c ###################################################################### -class NetImagePair(nn.Module): +def create_sequences_pairs(train = False): + nb, length = 10000, 1024 + noise_level = 2e-2 + + ha = torch.randint(args.nb_classes, (nb, ), device = device) + 1 + if args.independent: + hb = torch.randint(args.nb_classes, (nb, ), device = device) + else: + hb = ha + + pos = torch.empty(nb, device = device).uniform_(0.0, 0.9) + a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1) + a = a - pos.view(nb, 1) + a = (a >= 0).float() * torch.exp(-a * math.log(2) / 0.1) + a = a * ha.float().view(-1, 1).expand_as(a) / (1 + args.nb_classes) + noise = a.new(a.size()).normal_(0, noise_level) + a = a + noise + + pos = torch.empty(nb, device = device).uniform_(0.0, 0.5) + b1 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1) + b1 = b1 - pos.view(nb, 1) + b1 = (b1 >= 0).float() * torch.exp(-b1 * math.log(2) / 0.1) * 0.25 + pos = pos + hb.float() / (args.nb_classes + 1) * 0.5 + # pos += pos.new(hb.size()).uniform_(0.0, 0.01) + b2 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1) + b2 = b2 - pos.view(nb, 1) + b2 = (b2 >= 0).float() * torch.exp(-b2 * math.log(2) / 0.1) * 0.25 + + b = b1 + b2 + noise = b.new(b.size()).normal_(0, noise_level) + b = b + noise + + # a = (a - a.mean()) / a.std() + # b = (b - b.mean()) / b.std() + + return a, b, ha + +###################################################################### + +class NetForImagePair(nn.Module): def __init__(self): - super(NetImagePair, self).__init__() + super(NetForImagePair, self).__init__() self.features_a = nn.Sequential( nn.Conv2d(1, 16, kernel_size = 5), nn.MaxPool2d(3), nn.ReLU(), @@ -148,9 +268,9 @@ class NetImagePair(nn.Module): ###################################################################### -class NetImageValuesPair(nn.Module): +class NetForImageValuesPair(nn.Module): def __init__(self): - super(NetImageValuesPair, self).__init__() + super(NetForImageValuesPair, self).__init__() self.features_a = nn.Sequential( nn.Conv2d(1, 16, kernel_size = 5), nn.MaxPool2d(3), nn.ReLU(), @@ -178,76 +298,128 @@ class NetImageValuesPair(nn.Module): ###################################################################### +class NetForSequencePair(nn.Module): + + def feature_model(self): + kernel_size = 11 + pooling_size = 4 + return nn.Sequential( + nn.Conv1d( 1, self.nc, kernel_size = kernel_size), + nn.AvgPool1d(pooling_size), + nn.LeakyReLU(), + nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size), + nn.AvgPool1d(pooling_size), + nn.LeakyReLU(), + nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size), + nn.AvgPool1d(pooling_size), + nn.LeakyReLU(), + nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size), + nn.AvgPool1d(pooling_size), + nn.LeakyReLU(), + ) + + def __init__(self): + super(NetForSequencePair, self).__init__() + + self.nc = 32 + self.nh = 256 + + self.features_a = self.feature_model() + self.features_b = self.feature_model() + + self.fully_connected = nn.Sequential( + nn.Linear(2 * self.nc, self.nh), + nn.ReLU(), + nn.Linear(self.nh, 1) + ) + + def forward(self, a, b): + a = a.view(a.size(0), 1, a.size(1)) + a = self.features_a(a) + a = F.avg_pool1d(a, a.size(2)) + + b = b.view(b.size(0), 1, b.size(1)) + b = self.features_b(b) + b = F.avg_pool1d(b, b.size(2)) + + x = torch.cat((a.view(a.size(0), -1), b.view(b.size(0), -1)), 1) + return self.fully_connected(x) + +###################################################################### + if args.data == 'image_pair': create_pairs = create_image_pairs - model = NetImagePair() + model = NetForImagePair() + elif args.data == 'image_values_pair': create_pairs = create_image_values_pairs - model = NetImageValuesPair() + model = NetForImageValuesPair() + +elif args.data == 'sequence_pair': + create_pairs = create_sequences_pairs + model = NetForSequencePair() + + ## Save for figures + a, b, c = create_pairs() + for k in range(10): + file = open(f'train_{k:02d}.dat', 'w') + for i in range(a.size(1)): + file.write(f'{a[k, i]:f} {b[k,i]:f}\n') + file.close() + else: raise Exception('Unknown data ' + args.data) ###################################################################### +# Train -nb_epochs, batch_size = 50, 100 - -print('nb_parameters %d' % sum(x.numel() for x in model.parameters())) +print(f'nb_parameters {sum(x.numel() for x in model.parameters())}') -optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3) - -if torch.cuda.is_available(): - model.cuda() +model.to(device) -for e in range(nb_epochs): +input_a, input_b, classes = create_pairs(train = True) - input_a, input_b, count = create_pairs(train = True) +for e in range(args.nb_epochs): - # The information bound is the entropy of the class distribution - class_proba = count.float() - class_proba /= class_proba.sum() - class_entropy = - (class_proba.log() * class_proba).sum().item() + optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate) input_br = input_b[torch.randperm(input_b.size(0))] acc_mi = 0.0 - for batch_a, batch_b, batch_br in zip(input_a.split(batch_size), - input_b.split(batch_size), - input_br.split(batch_size)): + for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size), + input_b.split(args.batch_size), + input_br.split(args.batch_size)): mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log() - loss = - mi acc_mi += mi.item() + loss = - mi optimizer.zero_grad() loss.backward() optimizer.step() - acc_mi /= (input_a.size(0) // batch_size) + acc_mi /= (input_a.size(0) // args.batch_size) - print('%d %.04f %.04f' % (e, acc_mi / math.log(2), class_entropy / math.log(2))) + print(f'{e+1} {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}') sys.stdout.flush() ###################################################################### +# Test -input_a, input_b, count = create_pairs(train = False) +input_a, input_b, classes = create_pairs(train = False) -for e in range(nb_epochs): - class_proba = count.float() - class_proba /= class_proba.sum() - class_entropy = - (class_proba.log() * class_proba).sum().item() +input_br = input_b[torch.randperm(input_b.size(0))] - input_br = input_b[torch.randperm(input_b.size(0))] +acc_mi = 0.0 - acc_mi = 0.0 - - for batch_a, batch_b, batch_br in zip(input_a.split(batch_size), - input_b.split(batch_size), - input_br.split(batch_size)): - mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log() - acc_mi += mi.item() +for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size), + input_b.split(args.batch_size), + input_br.split(args.batch_size)): + mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log() + acc_mi += mi.item() - acc_mi /= (input_a.size(0) // batch_size) +acc_mi /= (input_a.size(0) // args.batch_size) -print('test %.04f %.04f'%(acc_mi / math.log(2), class_entropy / math.log(2))) +print(f'test {acc_mi / math.log(2):.04f} {entropy(classes) / math.log(2):.04f}') ######################################################################