Update.
authorFrancois Fleuret <francois@fleuret.org>
Fri, 21 Feb 2020 21:20:15 +0000 (22:20 +0100)
committerFrancois Fleuret <francois@fleuret.org>
Fri, 21 Feb 2020 21:20:15 +0000 (22:20 +0100)
dummy.net [new file with mode: 0644]
sizer.py [new file with mode: 0755]

diff --git a/dummy.net b/dummy.net
new file mode 100644 (file)
index 0000000..31e3a1e
--- /dev/null
+++ b/dummy.net
@@ -0,0 +1,8 @@
+(17, 3, 60, 80)
+nn.Conv2d(3, 32, 3, padding = 1)
+nn.MaxPool2d(2)
+nn.Conv2d(32, 32, 3, padding = 1)
+nn.MaxPool2d(2)
+nn.Conv2d(32, 64, 3, padding = 1)
+nn.MaxPool2d(5)
+nn.Conv2d(64, 128, (3, 4))
diff --git a/sizer.py b/sizer.py
new file mode 100755 (executable)
index 0000000..52620e8
--- /dev/null
+++ b/sizer.py
@@ -0,0 +1,26 @@
+#!/usr/bin/env python
+
+import os, stat, sys
+import time
+import torch
+from torch import nn
+
+t = 0
+
+while True:
+    pt = t
+    t = os.stat(sys.argv[1])[stat.ST_MTIME]
+    if t > pt:
+        pt = t
+        os.system('clear')
+        try:
+            temp = [l.strip('\n\r') for l in open(sys.argv[1], 'r').readlines()]
+            x = torch.zeros(eval(temp.pop(0)))
+            print('-> ' + str(tuple(x.size())))
+            for k in temp:
+                print('   ' + k)
+                x = eval(k + '(x)')
+                print('-> ' + str(tuple(x.size())))
+        except:
+            print('** Error **')
+    time.sleep(1)