c9a6c2f301ad50b8ea75ff05b01dcf89b8578db9
[pytorch.git] / tensorstack.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 from torch import Tensor
9
10 import sys
11
def exception_hook(exc_type, exc_value, tb):
    r'''Hacks the call stack message to show all the local variables in
    case of RuntimeError or ValueError (or any subclass of them), and
    prints tensors as shape, dtype and device.

    Meant to be installed as ``sys.excepthook``: ``exc_type`` is the
    exception class, ``exc_value`` the exception instance and ``tb`` its
    traceback. Everything is printed to stdout.

    ``Tensor.__repr__`` is temporarily replaced while the frames are
    printed and is always restored afterwards, even if printing fails.

    '''

    # Local import: keeps the file's top-level import block unchanged.
    import linecache

    # issubclass (not set membership) so subclasses such as
    # NotImplementedError are covered as well.
    dump_locals = issubclass(exc_type, (RuntimeError, ValueError))

    repr_orig = Tensor.__repr__
    Tensor.__repr__ = lambda x: f'{x.size()}:{x.dtype}:{x.device}'

    try:
        while tb:
            print('--------------------------------------------------\n')
            frame = tb.tb_frame
            filename = frame.f_code.co_filename
            name = frame.f_code.co_name
            line_no = tb.tb_lineno
            print(f'  File "{filename}", line {line_no}, in {name}')
            # linecache does not leak file handles and returns '' instead of
            # raising when the source is unavailable (e.g. "<stdin>").
            print(linecache.getline(filename, line_no))

            if dump_locals:
                for n, v in frame.f_locals.items():
                    try:
                        text = f'  {n} -> {v}'
                    except Exception:
                        # A local with a broken __repr__/__str__ must not
                        # abort the whole report.
                        text = f'  {n} -> <unprintable {type(v).__name__}>'
                    print(text)

            print()
            tb = tb.tb_next
    finally:
        # Restore the normal tensor printing no matter what happened above.
        Tensor.__repr__ = repr_orig

    print(f'{exc_type.__name__}: {exc_value}')


# Importing this module installs the hook for all uncaught exceptions.
sys.excepthook = exception_hook
42
43 ######################################################################
44
if __name__ == '__main__':

    # Small demo of the hook: the first blah() call succeeds, the second
    # one raises a shape-mismatch RuntimeError, which exception_hook
    # reports together with the local variables of every frame.

    import torch

    def dummy(a, b):
        print(a @ b)

    def blah(a, b):
        c = b + b
        dummy(a, c)

    mmm = torch.randn(2, 3)
    xxx = torch.randn(3)
    blah(mmm, xxx)  # (2, 3) @ (3,) -> (2,): works
    blah(xxx, mmm)  # (3,) @ (2, 3): size mismatch -> RuntimeError