X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=ae_size.py;fp=ae_size.py;h=7bef9f507f8eb6ff15bbd81b40e0ad4ee2e2d927;hb=2db4624955ad2a1c29f7632f30ac217c045638cf;hp=0000000000000000000000000000000000000000;hpb=7b534d513dd1bdad208f2b59baf2d47979f90663;p=pytorch.git diff --git a/ae_size.py b/ae_size.py new file mode 100755 index 0000000..7bef9f5 --- /dev/null +++ b/ae_size.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python + +import math +from torch import nn +from torch import Tensor + +###################################################################### + +def minimal_input_size(w, layer_specs): + assert w > 0, 'The input is too small' + if layer_specs == []: + return w + else: + k, s = layer_specs[0] + w = math.ceil((w - k) / s) + 1 + w = minimal_input_size(w, layer_specs[1:]) + return int((w - 1) * s + k) + +###################################################################### + +layer_specs = [ (11, 5), (5, 2), (3, 2), (3, 2) ] + +layers = [] +for l in layer_specs: + layers.append(nn.Conv2d(1, 1, l[0], l[1])) + +for l in reversed(layer_specs): + layers.append(nn.ConvTranspose2d(1, 1, l[0], l[1])) + +m = nn.Sequential(*layers) + +h = minimal_input_size(240, layer_specs) +w = minimal_input_size(320, layer_specs) + +x = Tensor(1, 1, h, w).normal_() + +print(x.size(), m(x).size())