######################################################################
def conv_chain(input_size, output_size, depth, cond):
    """Enumerate chains of 1D convolutions mapping ``input_size`` to ``output_size``.

    A chain is a list of exactly ``depth`` ``(kernel_size, stride)`` pairs,
    applied first to last. A candidate layer is kept only when

    * ``cond(depth, kernel_size, stride)`` is true, where ``depth`` is the
      number of layers still to place (the first layer sees the caller's
      ``depth``, the last layer sees ``1``), and
    * it tiles its input exactly: with ``n`` the layer's output length,
      ``(n - 1) * stride + kernel_size == input_size``, so no trailing input
      element is silently dropped.

    Args:
        input_size: length of the signal fed to the first layer.
        output_size: required length after the last layer.
        depth: number of layers per chain; must be non-negative.
        cond: predicate ``cond(depth, kernel_size, stride) -> bool`` used to
            restrict the layers considered at each position.

    Returns:
        A list of chains, each a list of ``(kernel_size, stride)`` tuples,
        ordered by kernel size then stride of the first layer (and so on
        recursively). ``[[]]`` when ``depth == 0`` and the sizes already
        match; ``[]`` when no chain exists.

    Raises:
        ValueError: if ``depth`` is negative.
    """
    if depth < 0:
        # Without this check, kernel=1/stride=1 keeps the size unchanged and
        # the recursion never reaches the base case.
        raise ValueError("depth must be non-negative")
    if depth == 0:
        # Zero layers left: the empty chain works iff the sizes already agree.
        return [[]] if input_size == output_size else []

    chains = []
    for kernel_size in range(1, input_size + 1):
        for stride in range(1, input_size + 1):
            if not cond(depth, kernel_size, stride):
                continue
            # Output length of an unpadded, undilated convolution.
            n = (input_size - kernel_size) // stride + 1
            # Reject layers that would leave unread input at the end.
            if (n - 1) * stride + kernel_size != input_size:
                continue
            # A convolution never lengthens its input, so once we are below
            # the target no deeper chain can reach it; skip the recursion.
            if n < output_size:
                continue
            for tail in conv_chain(n, output_size, depth - 1, cond):
                chains.append([(kernel_size, stride)] + tail)
    return chains
######################################################################
if __name__ == "__main__":
    # Demo: build every enumerated chain as a stack of nn.Conv1d layers and
    # print the input/output sizes, so each chain can be checked to really
    # map a length-64 signal to length 8.
    import torch
    from torch import nn

    chains = conv_chain(
        input_size=64,
        output_size=8,
        # NOTE(review): the depth argument was not visible in the reviewed
        # fragment; 5 is an assumed value (cond only allows stride > 1 in the
        # last two layers, d < 3) -- confirm against the original script.
        depth=5,
        cond=lambda d, k, s: k <= 4 and s <= k and (s == 1 or d < 3),
    )

    x = torch.rand(1, 1, 64)

    for chain in chains:
        model = nn.Sequential(
            *[nn.Conv1d(1, 1, kernel_size, stride) for kernel_size, stride in chain]
        )
        print(x.size(), model(x).size())
######################################################################