type = int, default = 100,
help = 'Batch size')
+parser.add_argument('--independent', action = 'store_true',
+ help = 'Should the pair components be independent')
+
######################################################################
def entropy(target):
c = torch.cat(uc, 0)
perm = torch.randperm(a.size(0))
a = a[perm].contiguous()
+
+ if args.independent:
+ perm = torch.randperm(a.size(0))
b = b[perm].contiguous()
return a, b, c
b = a.new(a.size(0), 2)
b[:, 0].uniform_(0.0, 10.0)
b[:, 1].uniform_(0.0, 0.5)
- b[:, 1] += b[:, 0] + target.float()
+
+ if args.independent:
+ b[:, 1] += b[:, 0] + used_MNIST_classes[torch.randint(len(used_MNIST_classes), target.size())]
+ else:
+ b[:, 1] += b[:, 0] + target.float()
return a, b, c
noise_level = 2e-2
ha = torch.randint(args.nb_classes, (nb, ), device = device) + 1
- # hb = torch.randint(args.nb_classes, (nb, ), device = device)
- hb = ha
+ if args.independent:
+ hb = torch.randint(args.nb_classes, (nb, ), device = device)
+ else:
+ hb = ha
pos = torch.empty(nb, device = device).uniform_(0.0, 0.9)
a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
######################################################################
a, b, c = create_pairs()
for k in range(10):
- file = open(f'/tmp/train_{k:02d}.dat', 'w')
+ file = open(f'train_{k:02d}.dat', 'w')
for i in range(a.size(1)):
file.write(f'{a[k, i]:f} {b[k,i]:f}\n')
file.close()