X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=sizer.py;h=cc0a19e26d19f5157de277cda7e8b0cea80f1688;hp=52620e88aac6600f5d5e80400a7162abc84991b2;hb=f99e2c83638c960d158c17270c072876834df9a9;hpb=7080726691be9341436db5a664778679600c5f62 diff --git a/sizer.py b/sizer.py index 52620e8..cc0a19e 100755 --- a/sizer.py +++ b/sizer.py @@ -1,5 +1,10 @@ #!/usr/bin/env python +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + import os, stat, sys import time import torch @@ -7,6 +12,21 @@ from torch import nn t = 0 +if len(sys.argv) < 2: + print(sys.argv[0] + ''' + +For example: + +(17, 3, 60, 80) +nn.Conv2d(3, 32, 3, padding = 1) +nn.MaxPool2d(2) +nn.Conv2d(32, 32, 3, padding = 1) +nn.MaxPool2d(2) +nn.Conv2d(32, 64, 3, padding = 1) +nn.MaxPool2d(5) +nn.Conv2d(64, 64, (3, 4))''') + exit(1) + while True: pt = t t = os.stat(sys.argv[1])[stat.ST_MTIME]