+ return a, b, c
+
+######################################################################
+
+class NetImagePair(nn.Module):
+ def __init__(self):
+ super(NetImagePair, self).__init__()
+ self.features_a = nn.Sequential(
+ nn.Conv2d(1, 16, kernel_size = 5),
+ nn.MaxPool2d(3), nn.ReLU(),
+ nn.Conv2d(16, 32, kernel_size = 5),
+ nn.MaxPool2d(2), nn.ReLU(),
+ )
+
+ self.features_b = nn.Sequential(
+ nn.Conv2d(1, 16, kernel_size = 5),
+ nn.MaxPool2d(3), nn.ReLU(),
+ nn.Conv2d(16, 32, kernel_size = 5),
+ nn.MaxPool2d(2), nn.ReLU(),
+ )
+
+ self.fully_connected = nn.Sequential(
+ nn.Linear(256, 200),
+ nn.ReLU(),
+ nn.Linear(200, 1)
+ )
+
+ def forward(self, a, b):
+ a = self.features_a(a).view(a.size(0), -1)
+ b = self.features_b(b).view(b.size(0), -1)
+ x = torch.cat((a, b), 1)
+ return self.fully_connected(x)
+
+######################################################################
+
+class NetImageValuesPair(nn.Module):
+ def __init__(self):
+ super(NetImageValuesPair, self).__init__()
+ self.features_a = nn.Sequential(
+ nn.Conv2d(1, 16, kernel_size = 5),
+ nn.MaxPool2d(3), nn.ReLU(),
+ nn.Conv2d(16, 32, kernel_size = 5),
+ nn.MaxPool2d(2), nn.ReLU(),
+ )
+
+ self.features_b = nn.Sequential(
+ nn.Linear(2, 32), nn.ReLU(),
+ nn.Linear(32, 32), nn.ReLU(),
+ nn.Linear(32, 128), nn.ReLU(),
+ )
+
+ self.fully_connected = nn.Sequential(
+ nn.Linear(256, 200),
+ nn.ReLU(),
+ nn.Linear(200, 1)
+ )
+
+ def forward(self, a, b):
+ a = self.features_a(a).view(a.size(0), -1)
+ b = self.features_b(b).view(b.size(0), -1)
+ x = torch.cat((a, b), 1)
+ return self.fully_connected(x)
+
+######################################################################
+
+if args.data == 'image_pair':
+ create_pairs = create_image_pairs
+ model = NetImagePair()
+elif args.data == 'image_values_pair':
+ create_pairs = create_image_values_pairs
+ model = NetImageValuesPair()
+else:
+ raise Exception('Unknown data ' + args.data)
+
+######################################################################