27de2040b98c342cf48468b27d9a0e5c0b834ba3
[pytorch.git] / stack.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 from torch import is_tensor
9
10 import sys
11
def exception_hook(exc_type, exc_value, tb):
    """Print a traceback annotated with each frame's local variables.

    Installed below as ``sys.excepthook``. The outermost traceback entry is
    skipped. For an uncaught exception in a script this is typically the
    module-level ``<module>`` frame, whose "locals" are all the globals.

    For every remaining frame it prints:

    * the location line,
    * the source line, when it can be retrieved,
    * one line per local variable. Tensors are summarized as
      ``size:dtype:device``; any other value is formatted with ``str``.

    The exception type name and message are printed last. All output goes
    to stdout via ``print``.

    Args:
        exc_type: the exception class.
        exc_value: the exception instance.
        tb: the traceback object; ``None`` is tolerated.
    """
    import linecache  # stdlib; never raises for unreadable/pseudo files

    # Skip the outermost entry (see docstring). Guarding against None
    # matters: an error raised inside an excepthook hides the original one.
    tb = tb.tb_next if tb is not None else None

    while tb is not None:
        frame = tb.tb_frame
        filename = frame.f_code.co_filename
        name = frame.f_code.co_name
        line_no = tb.tb_lineno
        print(f'  File "{filename}", line {line_no}, in {name}')
        # linecache returns '' for files that cannot be read (e.g. "<string>",
        # "<stdin>"), closes what it opens, and caches across frames.
        print(linecache.getline(filename, line_no), end='')

        for n, v in frame.f_locals.items():
            try:
                if is_tensor(v):
                    desc = f'{v.size()}:{v.dtype}:{v.device}'
                else:
                    desc = f'{v}'
            except Exception as err:
                # A broken __str__ must not abort the whole report.
                desc = f'<unprintable {type(v).__name__}: {err!r}>'
            print(f'  {n} -> {desc}')

        tb = tb.tb_next

    print(f'{exc_type.__name__}: {exc_value}')


sys.excepthook = exception_hook
40
41 ######################################################################
42
if __name__ == '__main__':

    # Demo for the hook installed above. The first call multiplies a (2, 3)
    # matrix by a (3,) vector. The second multiplies a (3,) vector by a
    # (2, 3) matrix, which torch rejects. That uncaught error is then
    # reported by sys.excepthook.
    import torch

    def dummy(a, b):
        """Print the matrix product ``a @ b``."""
        print(a @ b)

    def blah(a, b):
        """Double ``b`` and pass it, with ``a``, to ``dummy``."""
        c = b + b
        dummy(a, c)

    matrix = torch.randn(2, 3)
    vector = torch.randn(3)
    for first, second in ((matrix, vector), (vector, matrix)):
        blah(first, second)
58