X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=mine_mnist.py;h=389544b780f7b5d051460d914a25285f4251cc5d;hp=0c485b20ad4689f6652bbf4da6078f789f6e8950;hb=4469498b31c1fb90cb2b1202dbaf86be0f2d18b0;hpb=99fab4ddc7ee5fedf7a898a9263e2c271ea7d721 diff --git a/mine_mnist.py b/mine_mnist.py index 0c485b2..389544b 100755 --- a/mine_mnist.py +++ b/mine_mnist.py @@ -47,6 +47,9 @@ parser.add_argument('--batch_size', type = int, default = 100, help = 'Batch size') +parser.add_argument('--independent', action = 'store_true', + help = 'Should the pair components be independent') + ###################################################################### def entropy(target): @@ -116,6 +119,9 @@ def create_image_pairs(train = False): c = torch.cat(uc, 0) perm = torch.randperm(a.size(0)) a = a[perm].contiguous() + + if args.independent: + perm = torch.randperm(a.size(0)) b = b[perm].contiguous() return a, b, c @@ -150,7 +156,11 @@ def create_image_values_pairs(train = False): b = a.new(a.size(0), 2) b[:, 0].uniform_(0.0, 10.0) b[:, 1].uniform_(0.0, 0.5) - b[:, 1] += b[:, 0] + target.float() + + if args.independent: + b[:, 1] += b[:, 0] + used_MNIST_classes[torch.randint(len(used_MNIST_classes), target.size())] + else: + b[:, 1] += b[:, 0] + target.float() return a, b, c @@ -161,8 +171,10 @@ def create_sequences_pairs(train = False): noise_level = 2e-2 ha = torch.randint(args.nb_classes, (nb, ), device = device) + 1 - # hb = torch.randint(args.nb_classes, (nb, ), device = device) - hb = ha + if args.independent: + hb = torch.randint(args.nb_classes, (nb, ), device = device) + else: + hb = ha pos = torch.empty(nb, device = device).uniform_(0.0, 0.9) a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1) @@ -314,7 +326,7 @@ elif args.data == 'sequence_pair': ###################################################################### a, b, c = create_pairs() for k in range(10): - file = open(f'/tmp/train_{k:02d}.dat', 'w') + file = open(f'train_{k:02d}.dat', 'w') for i in range(a.size(1)): file.write(f'{a[k, i]:f} {b[k,i]:f}\n') file.close()