From ba45285b08782597aacd2764a7506b28a0fbf5d2 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Sun, 6 Jun 2021 14:30:34 +0200 Subject: [PATCH] Update. --- conv_chain.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/conv_chain.py b/conv_chain.py index 04dfdfa..fa5d752 100755 --- a/conv_chain.py +++ b/conv_chain.py @@ -24,19 +24,21 @@ def conv_chain(input_size, output_size, depth, cond): ###################################################################### -# Example +if __name__ == "__main__": -c = conv_chain( - input_size = 64, output_size = 8, - depth = 5, - cond = lambda k, s: k <= 4 and s <= 2 and s <= k//2 -) + # Example -x = torch.rand(1, 1, 64) + c = conv_chain( + input_size = 64, output_size = 8, + depth = 5, + cond = lambda k, s: k <= 4 and s <= 2 and s <= k//2 + ) -for m in c: - m = nn.Sequential(*[ nn.Conv1d(1, 1, l[0], l[1]) for l in m ]) - print(m) - print(x.size(), m(x).size()) + x = torch.rand(1, 1, 64) + + for m in c: + model = nn.Sequential(*[ nn.Conv1d(1, 1, l[0], l[1]) for l in m ]) + print(model) + print(x.size(), model(x).size()) ###################################################################### -- 2.20.1