X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=tensorstack.py;h=42a051e09c53c03316aac9f1c0fa6dc092fd5b8d;hp=c9a6c2f301ad50b8ea75ff05b01dcf89b8578db9;hb=05b9b133a45ac9bd5abe6f8b6d29095f9c82797a;hpb=ca897077ed89fbc3c7e8d812ad262146a0c72b71 diff --git a/tensorstack.py b/tensorstack.py index c9a6c2f..42a051e 100755 --- a/tensorstack.py +++ b/tensorstack.py @@ -9,52 +9,53 @@ from torch import Tensor import sys + def exception_hook(exc_type, exc_value, tb): - r'''Hacks the call stack message to show all the local variables in - case of RuntimeError or ValueError, and prints tensors as shape, - dtype and device. + r"""Hacks the call stack message to show all the local variables + in case of RuntimeError, ValueError, or TypeError and prints + tensors as shape, dtype and device. - ''' + """ - repr_orig=Tensor.__repr__ - Tensor.__repr__=lambda x: f'{x.size()}:{x.dtype}:{x.device}' + repr_orig = Tensor.__repr__ + Tensor.__repr__ = lambda x: f"{x.size()}:{x.dtype}:{x.device}" while tb: - print('--------------------------------------------------\n') + print("--------------------------------------------------\n") filename = tb.tb_frame.f_code.co_filename name = tb.tb_frame.f_code.co_name line_no = tb.tb_lineno print(f' File "{filename}", line {line_no}, in {name}') - print(open(filename, 'r').readlines()[line_no-1]) + print(open(filename, "r").readlines()[line_no - 1]) - if exc_type in { RuntimeError, ValueError }: - for n,v in tb.tb_frame.f_locals.items(): - print(f' {n} -> {v}') + if exc_type in {RuntimeError, ValueError, TypeError}: + for n, v in tb.tb_frame.f_locals.items(): + print(f" {n} -> {v}") print() tb = tb.tb_next - Tensor.__repr__=repr_orig + Tensor.__repr__ = repr_orig + + print(f"{exc_type.__name__}: {exc_value}") - print(f'{exc_type.__name__}: {exc_value}') sys.excepthook = exception_hook ###################################################################### -if __name__ == '__main__': - +if __name__ == "__main__": import torch - def dummy(a,b): - print(a@b) + def dummy(a, b): + print(a @ b) - def blah(a,b): - c=b+b - dummy(a,c) + def blah(a, b): + c = b + b + dummy(a, c) - mmm=torch.randn(2,3) - xxx=torch.randn(3) - #print(xxx@mmm) - blah(mmm,xxx) - blah(xxx,mmm) + mmm = torch.randn(2, 3) + xxx = torch.randn(3) + # print(xxx@mmm) + blah(mmm, xxx) + blah(xxx, mmm)