From 4469498b31c1fb90cb2b1202dbaf86be0f2d18b0 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Tue, 18 Dec 2018 09:55:05 +0100 Subject: [PATCH] Update. --- confidence.py | 53 +++++++++++++++++++++++++++++++++++++++++++++++++++ mine_mnist.py | 20 +++++++++++++++---- 2 files changed, 69 insertions(+), 4 deletions(-) create mode 100755 confidence.py diff --git a/confidence.py b/confidence.py new file mode 100755 index 0000000..ff4b395 --- /dev/null +++ b/confidence.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python + +import math + +import torch, torchvision + +from torch import nn +from torch.nn import functional as F + +###################################################################### + +nb = 100 +delta = 0.35 +x = torch.empty(nb).uniform_(0.0, delta) +x += x.new_full(x.size(), 0.5).bernoulli() * (1 - delta) + +a = x * math.pi * 2 * 4 +b = x * math.pi * 2 * 3 +y = a.sin() + b + +x = x.view(-1, 1) +y = y.view(-1, 1) + +###################################################################### + +nh = 100 + +model = nn.Sequential(nn.Linear(1, nh), nn.ReLU(), + nn.Linear(nh, nh), nn.ReLU(), + nn.Linear(nh, 1)) + +criterion = nn.MSELoss() +optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3) + +for k in range(10000): + loss = criterion(model(x), y) + if (k+1)%100 == 0: print(k+1, loss.item()) + optimizer.zero_grad() + loss.backward() + optimizer.step() + +###################################################################### + +import matplotlib.pyplot as plt + +fig, ax = plt.subplots() +ax.scatter(x.numpy(), y.numpy()) + +u = torch.linspace(0, 1, 100).view(-1, 1) +ax.plot(u.numpy(), model(u).detach().numpy(), color = 'red') +plt.show() + +###################################################################### diff --git a/mine_mnist.py b/mine_mnist.py index 0c485b2..389544b 100755 --- a/mine_mnist.py +++ b/mine_mnist.py @@ -47,6 +47,9 @@ parser.add_argument('--batch_size', type = int, default = 100, help = 'Batch size') +parser.add_argument('--independent', action = 'store_true', + help = 'Should the pair components be independent') + ###################################################################### def entropy(target): @@ -116,6 +119,9 @@ def create_image_pairs(train = False): c = torch.cat(uc, 0) perm = torch.randperm(a.size(0)) a = a[perm].contiguous() + + if args.independent: + perm = torch.randperm(a.size(0)) b = b[perm].contiguous() return a, b, c @@ -150,7 +156,11 @@ def create_image_values_pairs(train = False): b = a.new(a.size(0), 2) b[:, 0].uniform_(0.0, 10.0) b[:, 1].uniform_(0.0, 0.5) - b[:, 1] += b[:, 0] + target.float() + + if args.independent: + b[:, 1] += b[:, 0] + used_MNIST_classes[torch.randint(len(used_MNIST_classes), target.size())] + else: + b[:, 1] += b[:, 0] + target.float() return a, b, c @@ -161,8 +171,10 @@ def create_sequences_pairs(train = False): noise_level = 2e-2 ha = torch.randint(args.nb_classes, (nb, ), device = device) + 1 - # hb = torch.randint(args.nb_classes, (nb, ), device = device) - hb = ha + if args.independent: + hb = torch.randint(args.nb_classes, (nb, ), device = device) + else: + hb = ha pos = torch.empty(nb, device = device).uniform_(0.0, 0.9) a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1) @@ -314,7 +326,7 @@ elif args.data == 'sequence_pair': ###################################################################### a, b, c = create_pairs() for k in range(10): - file = open(f'/tmp/train_{k:02d}.dat', 'w') + file = open(f'train_{k:02d}.dat', 'w') for i in range(a.size(1)): file.write(f'{a[k, i]:f} {b[k,i]:f}\n') file.close() -- 2.20.1