+
+class ConvNet(nn.Module):
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+
+ ks, nc = 5, 64
+
+ self.core = nn.Sequential(
+ nn.Conv2d(in_channels, nc, ks, padding = ks//2),
+ nn.ReLU(),
+ nn.Conv2d(nc, nc, ks, padding = ks//2),
+ nn.ReLU(),
+ nn.Conv2d(nc, nc, ks, padding = ks//2),
+ nn.ReLU(),
+ nn.Conv2d(nc, nc, ks, padding = ks//2),
+ nn.ReLU(),
+ nn.Conv2d(nc, nc, ks, padding = ks//2),
+ nn.ReLU(),
+ nn.Conv2d(nc, out_channels, ks, padding = ks//2),
+ )
+
+ def forward(self, x):
+ return self.core(x)
+
+######################################################################
+# Data