Update.
[pytorch.git] / mine_mnist.py
index 0c485b2..389544b 100755 (executable)
@@ -47,6 +47,9 @@ parser.add_argument('--batch_size',
                     type = int, default = 100,
                     help = 'Batch size')
 
+parser.add_argument('--independent', action = 'store_true',
+                    help = 'Should the pair components be independent')
+
 ######################################################################
 
 def entropy(target):
@@ -116,6 +119,9 @@ def create_image_pairs(train = False):
     c = torch.cat(uc, 0)
     perm = torch.randperm(a.size(0))
     a = a[perm].contiguous()
+
+    if args.independent:
+        perm = torch.randperm(a.size(0))
     b = b[perm].contiguous()
 
     return a, b, c
@@ -150,7 +156,11 @@ def create_image_values_pairs(train = False):
     b = a.new(a.size(0), 2)
     b[:, 0].uniform_(0.0, 10.0)
     b[:, 1].uniform_(0.0, 0.5)
-    b[:, 1] += b[:, 0] + target.float()
+
+    if args.independent:
+        b[:, 1] += b[:, 0] + used_MNIST_classes[torch.randint(len(used_MNIST_classes), target.size())]
+    else:
+        b[:, 1] += b[:, 0] + target.float()
 
     return a, b, c
 
@@ -161,8 +171,10 @@ def create_sequences_pairs(train = False):
     noise_level = 2e-2
 
     ha = torch.randint(args.nb_classes, (nb, ), device = device) + 1
-    # hb = torch.randint(args.nb_classes, (nb, ), device = device)
-    hb = ha
+    if args.independent:
+        hb = torch.randint(args.nb_classes, (nb, ), device = device)
+    else:
+        hb = ha
 
     pos = torch.empty(nb, device = device).uniform_(0.0, 0.9)
     a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
@@ -314,7 +326,7 @@ elif args.data == 'sequence_pair':
     ######################################################################
     a, b, c = create_pairs()
     for k in range(10):
-        file = open(f'/tmp/train_{k:02d}.dat', 'w')
+        file = open(f'train_{k:02d}.dat', 'w')
         for i in range(a.size(1)):
             file.write(f'{a[k, i]:f} {b[k,i]:f}\n')
         file.close()