projects
/
pytorch.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
5ab8211
)
Update.
author
Francois Fleuret
<francois@fleuret.org>
Sun, 6 Jun 2021 12:30:34 +0000
(14:30 +0200)
committer
Francois Fleuret
<francois@fleuret.org>
Sun, 6 Jun 2021 12:30:34 +0000
(14:30 +0200)
conv_chain.py
patch
|
blob
|
history
diff --git
a/conv_chain.py
b/conv_chain.py
index
04dfdfa
..
fa5d752
100755
(executable)
--- a/
conv_chain.py
+++ b/
conv_chain.py
@@
-24,19
+24,21
@@
def conv_chain(input_size, output_size, depth, cond):
######################################################################
######################################################################
-# Example
+if __name__ == "__main__":
-c = conv_chain(
- input_size = 64, output_size = 8,
- depth = 5,
- cond = lambda k, s: k <= 4 and s <= 2 and s <= k//2
-)
+ # Example
-x = torch.rand(1, 1, 64)
+ c = conv_chain(
+ input_size = 64, output_size = 8,
+ depth = 5,
+ cond = lambda k, s: k <= 4 and s <= 2 and s <= k//2
+ )
-for m in c:
- m = nn.Sequential(*[ nn.Conv1d(1, 1, l[0], l[1]) for l in m ])
- print(m)
- print(x.size(), m(x).size())
+ x = torch.rand(1, 1, 64)
+
+ for m in c:
+ model = nn.Sequential(*[ nn.Conv1d(1, 1, l[0], l[1]) for l in m ])
+ print(model)
+ print(x.size(), model(x).size())
######################################################################
######################################################################