0838bee9e025340f3058fedb7478ce14200ac788
[pytorch.git] / stack.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 from torch import Tensor
9
10 import sys
11
def exception_hook(exc_type, exc_value, tb):
    """Print an uncaught exception's traceback, with locals for RuntimeError.

    Meant to be installed as ``sys.excepthook``. For every frame of the
    traceback it prints a separator, a ``File ..., line ..., in ...`` header
    and the corresponding source line. When ``exc_type`` is exactly
    ``RuntimeError`` the local variables of each frame are printed as well.
    Finally the exception name and message are printed. Everything goes to
    standard output.

    While the frames are printed, ``Tensor.__repr__`` is temporarily replaced
    by a compact ``size:dtype:device`` summary so tensors held in locals do
    not flood the output; the original ``__repr__`` is always restored.

    Args:
        exc_type: Class of the exception being reported.
        exc_value: The exception instance.
        tb: Traceback object walked through ``tb_next``; a falsy value
            (e.g. ``None``) prints only the final summary line.
    """
    # Imported here so the function does not depend on a module-level import.
    import linecache

    repr_orig = Tensor.__repr__
    Tensor.__repr__ = lambda x: f'{x.size()}:{x.dtype}:{x.device}'

    try:
        while tb:
            frame = tb.tb_frame
            code = frame.f_code
            line_no = tb.tb_lineno

            print('--------------------------------------------------')
            print(f'  File "{code.co_filename}", line {line_no}, in {code.co_name}')
            # linecache closes the file it reads and returns '' when the
            # source is unavailable ("<stdin>", "<string>", generated code),
            # so a frame without a file on disk cannot crash the hook.
            source_line = linecache.getline(
                code.co_filename, line_no, frame.f_globals
            )
            print(source_line, end='')

            if exc_type is RuntimeError:
                for n, v in frame.f_locals.items():
                    print(f'  {n} -> {v}')

            tb = tb.tb_next
    finally:
        # Undo the monkey-patch even if printing a frame raised (for example
        # a local whose __str__ fails); otherwise tensors would keep the
        # abbreviated repr for the rest of the process.
        Tensor.__repr__ = repr_orig

    print(f'{exc_type.__name__}: {exc_value}')
34
# Install the hook as a side effect of importing/running this file: from here
# on, uncaught exceptions are reported through exception_hook rather than the
# interpreter's default traceback printer.
sys.excepthook = exception_hook
36
37 ######################################################################
38
if __name__ == '__main__':

    # Small demo, executed only when the file is run as a script: it triggers
    # a failing tensor operation two calls deep so the installed hook has
    # several frames (and their locals) to report.

    import torch

    def dummy(a,b):
        # Innermost frame: the matrix product is where a shape mismatch fails.
        print(a@b)

    def blah(a,b):
        # Intermediate frame; c is an extra local that shows up in the dump.
        c=b+b
        dummy(a,c)

    mmm=torch.randn(2,3)
    xxx=torch.randn(3)
    # Disabled variant that would hit the same shape mismatch directly at
    # module level, without the nested calls.
    #print(xxx@mmm)
    # (2,3) @ (3,) is a valid matrix-vector product, so this call just prints.
    blah(mmm,xxx)
    # (3,) @ (2,3) has incompatible shapes; PyTorch is expected to raise a
    # RuntimeError here, which is what exercises exception_hook.
    blah(xxx,mmm)