Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 6 Jul 2023 09:31:37 +0000 (11:31 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 6 Jul 2023 09:31:37 +0000 (11:31 +0200)
main.py
tensorstack.py [deleted file]

diff --git a/main.py b/main.py
index 5b49468..df3f154 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -14,7 +14,7 @@ import torch, torchvision
 from torch import nn
 from torch.nn import functional as F
 
-import mygpt, tasks, tensorstack
+import mygpt, tasks
 
 ######################################################################
 
@@ -384,16 +384,16 @@ train_set_perplexity = math.exp(entropy)
 train_examples = {}
 
 for input in task.batches(split="train"):
-    assert input.dim()==2 and input.dtype==torch.int64
+    assert input.dim() == 2 and input.dtype == torch.int64
     for x in input:
-        train_examples[x.sum().item()]=x
+        train_examples[x.sum().item()] = x
 
 for input in task.batches(split="test"):
-    assert input.dim()==2 and input.dtype==torch.int64
+    assert input.dim() == 2 and input.dtype == torch.int64
     for x in input:
         y = train_examples.get(x.sum().item())
         if y is not None:
-            assert x.size() != y.size() or (x-y).abs().sum() > 0
+            assert x.size() != y.size() or (x - y).abs().sum() > 0
 
 del train_examples
 
diff --git a/tensorstack.py b/tensorstack.py
deleted file mode 100755 (executable)
index 584c12d..0000000
+++ /dev/null
@@ -1,61 +0,0 @@
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-from torch import Tensor
-
-import sys
-
-
-def exception_hook(exc_type, exc_value, tb):
-    r"""Hacks the call stack message to show all the local variables in
-    case of RuntimeError or ValueError, and prints tensors as shape,
-    dtype and device.
-
-    """
-
-    repr_orig = Tensor.__repr__
-    Tensor.__repr__ = lambda x: f"{x.size()}:{x.dtype}:{x.device}"
-
-    while tb:
-        print("--------------------------------------------------\n")
-        filename = tb.tb_frame.f_code.co_filename
-        name = tb.tb_frame.f_code.co_name
-        line_no = tb.tb_lineno
-        print(f'  File "{filename}", line {line_no}, in {name}')
-        print(open(filename, "r").readlines()[line_no - 1])
-
-        if exc_type in {RuntimeError, ValueError}:
-            for n, v in tb.tb_frame.f_locals.items():
-                print(f"  {n} -> {v}")
-
-        print()
-        tb = tb.tb_next
-
-    Tensor.__repr__ = repr_orig
-
-    print(f"{exc_type.__name__}: {exc_value}")
-
-
-sys.excepthook = exception_hook
-
-######################################################################
-
-if __name__ == "__main__":
-    import torch
-
-    def dummy(a, b):
-        print(a @ b)
-
-    def blah(a, b):
-        c = b + b
-        dummy(a, c)
-
-    mmm = torch.randn(2, 3)
-    xxx = torch.randn(3)
-    # print(xxx@mmm)
-    blah(mmm, xxx)
-    blah(xxx, mmm)