Update.
[mygptrnn.git] / ffutils.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import torch
9 import sys, contextlib
10
11 import torch
12 from torch import Tensor
13
14 ######################################################################
15
16
17 @contextlib.contextmanager
18 def evaluation(*models):
19     with torch.inference_mode():
20         t = [(m, m.training) for m in models]
21         for m in models:
22             m.train(False)
23         yield
24         for m, u in t:
25             m.train(u)
26
27
28 ######################################################################
29
30 from torch.utils._python_dispatch import TorchDispatchMode
31
32
33 def hasNaN(x):
34     if torch.is_tensor(x):
35         return x.numel() > 0 and x.isnan().max()
36     else:
37         try:
38             return any([hasNaN(y) for y in x])
39         except TypeError:
40             return False
41
42
43 class NaNDetect(TorchDispatchMode):
44     def __torch_dispatch__(self, func, types, args, kwargs=None):
45         kwargs = kwargs or {}
46         res = func(*args, **kwargs)
47
48         if hasNaN(res):
49             raise RuntimeError(
50                 f"Function {func}(*{args}, **{kwargs}) " "returned a NaN"
51             )
52         return res
53
54
55 ######################################################################
56
57
58 def exception_hook(exc_type, exc_value, tb):
59     r"""Hacks the call stack message to show all the local variables
60     in case of relevant error, and prints tensors as shape, dtype and
61     device.
62
63     """
64
65     repr_orig = Tensor.__repr__
66     Tensor.__repr__ = lambda x: f"{x.size()}:{x.dtype}:{x.device}"
67
68     while tb:
69         print("--------------------------------------------------\n")
70         filename = tb.tb_frame.f_code.co_filename
71         name = tb.tb_frame.f_code.co_name
72         line_no = tb.tb_lineno
73         print(f'  File "{filename}", line {line_no}, in {name}')
74         print(open(filename, "r").readlines()[line_no - 1])
75
76         if exc_type in {RuntimeError, ValueError, IndexError, TypeError}:
77             for n, v in tb.tb_frame.f_locals.items():
78                 print(f"  {n} -> {v}")
79
80         print()
81         tb = tb.tb_next
82
83     Tensor.__repr__ = repr_orig
84
85     print(f"{exc_type.__name__}: {exc_value}")
86
87
88 def activate_tensorstack():
89     sys.excepthook = exception_hook
90
91
92 ######################################################################
93
94 if __name__ == "__main__":
95     import torch
96
97     def dummy(a, b):
98         print(a @ b)
99
100     def blah(a, b):
101         c = b + b
102         dummy(a, c)
103
104     mmm = torch.randn(2, 3)
105     xxx = torch.randn(3)
106     # print(xxx@mmm)
107     blah(mmm, xxx)
108     blah(xxx, mmm)