c = conv_chain(
input_size = 64, output_size = 8,
depth = 5,
- cond = lambda d, k, s: k <= 4 and s <= k and (s == 1 or d < 3)
+ # We want kernels smaller than 4, strides smaller than the
+ # kernels, and stride of 1 except in the two last layers
+ cond = lambda d, k, s: k <= 4 and s <= k and (s == 1 or d <= 2)
)
x = torch.rand(1, 1, 64)