Now also catch ValueError.
[pytorch.git] / tensorstack.py
index 544306c..d7b0d89 100755 (executable)
@@ -11,8 +11,8 @@ import sys
 
 def exception_hook(exc_type, exc_value, tb):
     r'''Hacks the call stack message in case of RuntimeError to show all
-    the local variables, and indicate for every tensor its shape,
-    dtype and device.
+    the local variables, and indicate for every tensor involved its
+    shape, dtype and device.
 
     '''
 
@@ -20,17 +20,18 @@ def exception_hook(exc_type, exc_value, tb):
     Tensor.__repr__=lambda x: f'{x.size()}:{x.dtype}:{x.device}'
 
     while tb:
-        print('--------------------------------------------------')
+        print('--------------------------------------------------\n')
         filename = tb.tb_frame.f_code.co_filename
         name = tb.tb_frame.f_code.co_name
         line_no = tb.tb_lineno
         print(f'  File "{filename}", line {line_no}, in {name}')
-        print(open(filename, 'r').readlines()[line_no-1], end='')
+        print(open(filename, 'r').readlines()[line_no-1])
 
-        if exc_type is RuntimeError:
+        if exc_type in { RuntimeError, ValueError }:
             for n,v in tb.tb_frame.f_locals.items():
                 print(f'  {n} -> {v}')
 
+        print()
         tb = tb.tb_next
 
     Tensor.__repr__=repr_orig