+ input = input[used_indices].contiguous()
+ target = target[used_indices].contiguous()
+
+ a = input
+
+ b = a.new(a.size(0), 2)
+ b[:, 0].uniform_(10)
+ b[:, 1].uniform_(0.5)
+ b[:, 1] += b[:, 0] + target.float()
+
+ c = torch.tensor([(target == k).sum().item() for k in used_MNIST_classes])
+
+ return a, b, c
+
+######################################################################
+
+class NetImagePair(nn.Module):
+ def __init__(self):
+ super(NetImagePair, self).__init__()
+ self.features_a = nn.Sequential(
+ nn.Conv2d(1, 16, kernel_size = 5),
+ nn.MaxPool2d(3), nn.ReLU(),
+ nn.Conv2d(16, 32, kernel_size = 5),
+ nn.MaxPool2d(2), nn.ReLU(),
+ )
+
+ self.features_b = nn.Sequential(
+ nn.Conv2d(1, 16, kernel_size = 5),
+ nn.MaxPool2d(3), nn.ReLU(),
+ nn.Conv2d(16, 32, kernel_size = 5),
+ nn.MaxPool2d(2), nn.ReLU(),
+ )
+
+ self.fully_connected = nn.Sequential(
+ nn.Linear(256, 200),
+ nn.ReLU(),
+ nn.Linear(200, 1)
+ )
+
+ def forward(self, a, b):
+ a = self.features_a(a).view(a.size(0), -1)
+ b = self.features_b(b).view(b.size(0), -1)
+ x = torch.cat((a, b), 1)
+ return self.fully_connected(x)
+
+######################################################################
+
+class NetImageValuesPair(nn.Module):
+ def __init__(self):
+ super(NetImageValuesPair, self).__init__()
+ self.features_a = nn.Sequential(
+ nn.Conv2d(1, 16, kernel_size = 5),
+ nn.MaxPool2d(3), nn.ReLU(),
+ nn.Conv2d(16, 32, kernel_size = 5),
+ nn.MaxPool2d(2), nn.ReLU(),
+ )
+
+ self.features_b = nn.Sequential(
+ nn.Linear(2, 32), nn.ReLU(),
+ nn.Linear(32, 32), nn.ReLU(),
+ nn.Linear(32, 128), nn.ReLU(),
+ )
+
+ self.fully_connected = nn.Sequential(
+ nn.Linear(256, 200),
+ nn.ReLU(),
+ nn.Linear(200, 1)
+ )
+
+ def forward(self, a, b):
+ a = self.features_a(a).view(a.size(0), -1)
+ b = self.features_b(b).view(b.size(0), -1)
+ x = torch.cat((a, b), 1)
+ return self.fully_connected(x)
+
+######################################################################
+
+if args.data == 'image_pair':
+ create_pairs = create_image_pairs
+ model = NetImagePair()
+elif args.data == 'image_values_pair':
+ create_pairs = create_image_values_pairs
+ model = NetImageValuesPair()
+else:
+ raise Exception('Unknown data ' + args.data)
+
+######################################################################