Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 21 Jan 2024 22:41:09 +0000 (23:41 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 21 Jan 2024 22:41:09 +0000 (23:41 +0100)
fridge
mygpt.py

diff --git a/fridge b/fridge
index 2cc6d01..82d2b17 100644 (file)
--- a/fridge
+++ b/fridge
@@ -302,3 +302,17 @@ class Calibrator:
         # G = (
         # torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
         # ).softmax(dim=2)
+
+######################################################################
+
+2024 Jan 21 16:55:24 (from main.py)
+
+        with open("test.dat", "a") as f:
+            for m filter(lambda m: isinstance(m,mygpt.Catenn.Linear),model.modules()):
+                for p in m.parameters() ]
+
+
+        for m in model.modules():
+            if isinstance(m, mygpt.Caterpillar):
+                
+
index b137cdb..040845e 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -493,10 +493,8 @@ class Caterpillar(nn.Module):
 
         warnings.warn("Caterpillar", RuntimeWarning)
 
-        def randw(*d, amplitude=None):
-            if amplitude is None:
-                amplitude = 1 / math.sqrt(d[-1])
-            return nn.Parameter(amplitude * torch.randn(*d))
+        def randw(*d, factor=1):
+            return nn.Parameter(torch.randn(*d) * factor / math.sqrt(d[-1]))
 
         self.caterpillar_length = caterpillar_length
         self.caterpillar_height = caterpillar_height
@@ -508,12 +506,11 @@ class Caterpillar(nn.Module):
 
         ######################################################################
 
-        default_bg = -math.log(caterpillar_height - 1)
-        self.w_G = randw(nb_heads, caterpillar_height, dim_model)
-        self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_bg))
+        self.w_G = randw(nb_heads, caterpillar_height, dim_model, factor=1.0)
+        self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), 0.0))
 
         self.w_K = randw(nb_heads, dim_qk, dim_model)
-        self.w_V = randw(nb_heads, dim_v, dim_model)
+        self.w_V = randw(nb_heads, dim_v, dim_model, factor=1)
         self.w_Q = randw(nb_heads, dim_qk, dim_model)
         self.w_O = randw(dim_v * nb_heads, dim_model)
 
@@ -586,17 +583,19 @@ class Caterpillar(nn.Module):
         # Clip the gating to avoid values greater than 1 when several
         # heads hit the same row
 
-        G = G / G.sum(1, keepdim=True).clamp(min=1)
+        # G = G / G.sum(1, keepdim=True).clamp(min=1)
+
+        H = (1 - G).log().sum(1, keepdim=True).exp()
 
         ######################################################################
 
         def recurrence(G, V, K):
             # We prepare the arguments for the parallel scan
 
-            A = 1 - G.sum(1)
+            A = H
 
-            gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
-            gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)
+            gated_V = torch.einsum("nhrt,nhtd->nrtd", H * G / (1 - G), V)
+            gated_K = torch.einsum("nhrt,nhtd->nrtd", H * G / (1 - G), K)
 
             # We start from cached values, which matters in inference
 
@@ -653,6 +652,8 @@ class Caterpillar(nn.Module):
                 next_V = next_V.detach()
                 next_K = next_K.detach()
 
+            warnings.warn("the rescaling is probably a bad idea", RuntimeWarning)
+
             next_V = next_V + (masked_next_V - masked_next_V.detach()) / (
                 1 - self.gate_dropout_proba
             )
@@ -814,7 +815,7 @@ class MyGPT(nn.Module):
         causal=False,
         dropout=0.0,
         len_max=1e5,
-        attention_layer="kvrec",
+        attention_layer="caterpillar",
         logger=print,
         args=None,
     ):
@@ -1019,7 +1020,111 @@ class MyGPT(nn.Module):
 ######################################################################
 
 if __name__ == "__main__":
-    print("Basic check.")
+    import argparse
+
+    import numpy as np
+    import matplotlib.pyplot as plt
+    import matplotlib.collections as mc
+
+    args = argparse.Namespace(
+        gate_dropout_proba=0.0, gate_dropout_sync=True, gate_dropout_replace=False
+    )
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    dim_model, dim_keys, nb_heads = 512, 64, 1
+    dropout = 0.1
+
+    caterpillar = Caterpillar(
+        dim_model=dim_model,
+        dim_qk=dim_keys,
+        dim_v=dim_model // nb_heads,
+        nb_heads=nb_heads,
+        caterpillar_length=16,
+        caterpillar_height=32,
+        attention_dropout=dropout,
+        args=args,
+    ).to(device)
+
+    qkv = QKVAttention(
+        dim_model=dim_model,
+        dim_qk=dim_keys,
+        dim_v=dim_model // nb_heads,
+        nb_heads=nb_heads,
+        causal=True,
+        attention_dropout=dropout,
+        args=args,
+    ).to(device)
+
+    linear = CacheWrapper(nn.Linear(512, 512)).to(device)
+
+    x = torch.randn(1, 256, dim_model)
+
+    x = x.to(device)
+    x.requires_grad_()
+
+    ######################################################################
+
+    fig = plt.figure()
+    fig.set_figheight(6)
+    fig.set_figwidth(8)
+
+    ax = fig.add_subplot(1, 1, 1)
+
+    # ax.set_xlim(-1.5, 1.5)
+    # ax.set_ylim(-1.5, 1.5)
+    # ax.set(aspect=1)
+    # ax.spines.right.set_visible(False)
+    # ax.spines.top.set_visible(False)
+
+    # dt = 0.01
+    # t = np.arange(dt, 20.0, dt)
+    # ax.semilogx(t, np.exp(-t / 5.0))
+    # ax.grid()
+
+    ######################################################################
+
+    for label, model in [
+        # ("nn.Linear", linear),
+        ("mygpy.QKVAttention", qkv),
+        ("mygpt.Caterpillar", caterpillar),
+    ]:
+        y = model(BracketedSequence(x, 32, x.size(1) - 32, init_cache=True)).x
+
+        data = []
+        for t in range(y.size(1)):
+            for d in torch.randperm(y.size(2))[:8]:
+                g = torch.autograd.grad(y[0, t, d], x, retain_graph=True)[0]
+                sg = g.pow(2).sum().item()
+                # sg = 0
+                # for p in model.parameters():
+                # g = torch.autograd.grad(y[0, t, d], p, retain_graph=True)[0]
+                # sg = sg + g.pow(2).sum().item()
+                data.append([t, sg])
+
+        data = torch.tensor(data)
+        ax.scatter(
+            data[:, 0], data[:, 1], s=1, label=label
+        )  # , color='gray', label='Input')
+
+    # ax.legend(frameon=False, loc="top right")
+
+    # Put a legend to the right of the current axis
+    box = ax.get_position()
+    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
+    ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
+
+    filename = "plot.pdf"
+    print(f"saving {filename}")
+    fig.savefig(filename, bbox_inches="tight")
+
+    # if args.window and hasattr(plt.get_current_fig_manager(), 'window'):
+    # plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
+    # plt.show()
+
+    exit(0)
+
+    ######################################################################
 
     m = Caterpillar(
         dim_model=4,
@@ -1041,8 +1146,6 @@ if __name__ == "__main__":
     print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())
     exit(0)
 
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
     vocabulary_size = 128
     x = torch.randint(vocabulary_size, (6, 1024))