X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=sizer.py;h=5887e4afd241a6d03e74a44970c1ceab29e88df4;hp=dff36ebc6ba9f279e270339f8a6c548bb8140f86;hb=05b9b133a45ac9bd5abe6f8b6d29095f9c82797a;hpb=ca897077ed89fbc3c7e8d812ad262146a0c72b71 diff --git a/sizer.py b/sizer.py index dff36eb..5887e4a 100755 --- a/sizer.py +++ b/sizer.py @@ -13,7 +13,9 @@ from torch import nn ###################################################################### if len(sys.argv) < 2: - print(sys.argv[0] + ''' + print( + sys.argv[0] + + """ For example: @@ -24,7 +26,8 @@ nn.Conv2d(32, 32, 3, padding = 1) nn.MaxPool2d(2) nn.Conv2d(32, 64, 3, padding = 1) nn.MaxPool2d(5) -nn.Conv2d(64, 64, (3, 4))''') +nn.Conv2d(64, 64, (3, 4))""" + ) exit(1) ###################################################################### @@ -36,15 +39,15 @@ while True: t = os.stat(sys.argv[1])[stat.ST_MTIME] if t > pt: pt = t - os.system('clear') + os.system("clear") try: - temp = [l.strip('\n\r') for l in open(sys.argv[1], 'r').readlines()] + temp = [l.strip("\n\r") for l in open(sys.argv[1], "r").readlines()] x = torch.zeros(eval(temp.pop(0))) - print('-> ' + str(tuple(x.size()))) + print("-> " + str(tuple(x.size()))) for k in temp: - print(' ' + k) - x = eval(k + '(x)') - print('-> ' + str(tuple(x.size()))) + print(" " + k) + x = eval(k + "(x)") + print("-> " + str(tuple(x.size()))) except: - print('** Error **') + print("** Error **") time.sleep(1)