Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 17 Jul 2023 20:51:28 +0000 (22:51 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 17 Jul 2023 20:51:28 +0000 (22:51 +0200)
ffutils.py [new file with mode: 0755]
main.py
tasks.py
world.py

diff --git a/ffutils.py b/ffutils.py
new file mode 100755 (executable)
index 0000000..45f44d8
--- /dev/null
@@ -0,0 +1,108 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import torch
+import sys, contextlib
+
+import torch
+from torch import Tensor
+
+######################################################################
+
+
+@contextlib.contextmanager
+def evaluation(*models):
+    with torch.inference_mode():
+        t = [(m, m.training) for m in models]
+        for m in models:
+            m.train(False)
+        yield
+        for m, u in t:
+            m.train(u)
+
+
+######################################################################
+
+from torch.utils._python_dispatch import TorchDispatchMode
+
+
+def hasNaN(x):
+    if torch.is_tensor(x):
+        return x.isnan().max()
+    else:
+        try:
+            return any([hasNaN(y) for y in x])
+        except TypeError:
+            return False
+
+
+class NaNDetect(TorchDispatchMode):
+    def __torch_dispatch__(self, func, types, args, kwargs=None):
+        kwargs = kwargs or {}
+        res = func(*args, **kwargs)
+
+        if hasNaN(res):
+            raise RuntimeError(
+                f"Function {func}(*{args}, **{kwargs}) " "returned a NaN"
+            )
+        return res
+
+
+######################################################################
+
+
+def exception_hook(exc_type, exc_value, tb):
+    r"""Hacks the call stack message to show all the local variables
+    in case of relevant error, and prints tensors as shape, dtype and
+    device.
+
+    """
+
+    repr_orig = Tensor.__repr__
+    Tensor.__repr__ = lambda x: f"{x.size()}:{x.dtype}:{x.device}"
+
+    while tb:
+        print("--------------------------------------------------\n")
+        filename = tb.tb_frame.f_code.co_filename
+        name = tb.tb_frame.f_code.co_name
+        line_no = tb.tb_lineno
+        print(f'  File "{filename}", line {line_no}, in {name}')
+        print(open(filename, "r").readlines()[line_no - 1])
+
+        if exc_type in {RuntimeError, ValueError, IndexError, TypeError}:
+            for n, v in tb.tb_frame.f_locals.items():
+                print(f"  {n} -> {v}")
+
+        print()
+        tb = tb.tb_next
+
+    Tensor.__repr__ = repr_orig
+
+    print(f"{exc_type.__name__}: {exc_value}")
+
+
+def activate_tensorstack():
+    sys.excepthook = exception_hook
+
+
+######################################################################
+
+if __name__ == "__main__":
+    import torch
+
+    def dummy(a, b):
+        print(a @ b)
+
+    def blah(a, b):
+        c = b + b
+        dummy(a, c)
+
+    mmm = torch.randn(2, 3)
+    xxx = torch.randn(3)
+    # print(xxx@mmm)
+    blah(mmm, xxx)
+    blah(xxx, mmm)
diff --git a/main.py b/main.py
index 69ee58f..e18887b 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -14,6 +14,7 @@ import torch, torchvision
 from torch import nn
 from torch.nn import functional as F
 
+import ffutils
 import mygpt, tasks
 
 ######################################################################
index 5583fc8..9cd06ae 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -75,11 +75,12 @@ class ProblemByheart(Problem):
     def __init__(self):
         nb_seq, len_prompt, len_result = 100, 5, 5
         self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result))
-        self.seq[:,len_prompt]=-1
+        self.seq[:, len_prompt] = -1
 
     def generate_sequences(self, nb):
         return self.seq[torch.randint(self.seq.size(0), (nb,))]
 
+
 class SandBox(Task):
     def __init__(
         self,
@@ -93,7 +94,7 @@ class SandBox(Task):
 
         self.batch_size = batch_size
 
-        problems = [ ProblemByheart() ]
+        problems = [ProblemByheart()]
         nb_common_codes = 100
 
         def generate_sequences(nb_samples):
@@ -101,7 +102,7 @@ class SandBox(Task):
             nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0)
             print(f"{nb_samples_per_problem}")
             all_seq = []
-            for nb, p in zip(nb_samples_per_problem,problems):
+            for nb, p in zip(nb_samples_per_problem, problems):
                 all_seq.append(p.generate_sequences(nb_samples_per_problem[nb]))
             return all_seq
 
@@ -109,7 +110,7 @@ class SandBox(Task):
         test_seq = generate_sequences(nb_test_samples)
 
         for strain, stest in zip(train_seq, test_seq):
-            s = torch.cat((strain,stest),0)
+            s = torch.cat((strain, stest), 0)
 
         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
 
index 3d6abbe..b35a08e 100755 (executable)
--- a/world.py
+++ b/world.py
@@ -85,9 +85,9 @@ def loss_H(binary_logits, h_threshold=1):
 def train_encoder(
     train_input,
     test_input,
-    depth=2,
+    depth,
+    nb_bits_per_token,
     dim_hidden=48,
-    nb_bits_per_token=8,
     lambda_entropy=0.0,
     lr_start=1e-3,
     lr_end=1e-4,
@@ -366,6 +366,8 @@ def create_data_and_processors(
     nb_test_samples,
     mode,
     nb_steps,
+    depth=3,
+    nb_bits_per_token=8,
     nb_epochs=10,
     device=torch.device("cpu"),
     device_storage=torch.device("cpu"),
@@ -388,6 +390,8 @@ def create_data_and_processors(
     encoder, quantizer, decoder = train_encoder(
         train_input,
         test_input,
+        depth=depth,
+        nb_bits_per_token=nb_bits_per_token,
         lambda_entropy=1.0,
         nb_epochs=nb_epochs,
         logger=logger,