Add blanket.py, a Blanket autograd function that renormalizes activations in the forward pass and gradients in the backward pass, and apply it to V, K, Q, and Y in Caterpillar. Simplify the gating recurrence, lower the w_G initialization factor to 1e-3, and rework the gradient-magnitude diagnostic plot.
author     François Fleuret <francois@fleuret.org>
           Tue, 23 Jan 2024 21:41:17 +0000 (22:41 +0100)
committer  François Fleuret <francois@fleuret.org>
           Tue, 23 Jan 2024 21:41:17 +0000 (22:41 +0100)
blanket.py [new file with mode: 0755]
mygpt.py

diff --git a/blanket.py b/blanket.py
new file mode 100755 (executable)
index 0000000..2b9c896
--- /dev/null
+++ b/blanket.py
@@ -0,0 +1,44 @@
+#!/usr/bin/env python
+
+import math
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+
+class Blanket(torch.autograd.Function):
+    @staticmethod
+    def normalize(x):
+        y = x.flatten(1)  # NxD view when x is contiguous, so the in-place ops modify x
+        y /= y.pow(2).sum(dim=1, keepdim=True).sqrt() + 1e-6  # unit norm per sample
+        y *= math.sqrt(y.numel() / y.size(0))  # rescale so each component has unit RMS
+
+    @staticmethod
+    def forward(ctx, x):
+        x = x.clone()
+        # Normalize the forward
+        Blanket.normalize(x)
+        return x
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_output = grad_output.clone()
+        # Normalize the gradient
+        Blanket.normalize(grad_output)
+        return grad_output
+
+
+blanket = Blanket.apply
+
+######################################################################
+
+if __name__ == "__main__":
+    x = torch.rand(2, 3).requires_grad_()
+    y = blanket(x) * 10
+    print(y.pow(2).sum())
+    z = y.sin().sum()
+    g = torch.autograd.grad(z, x)[0]
+
+    print(g.pow(2).sum())
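
Blanket acts as the identity up to a per-sample rescaling: the forward pass renormalizes each sample of its input to unit RMS, and the backward pass renormalizes the incoming gradient the same way, so gradient magnitudes become independent of any downstream scaling. A minimal sketch of this behaviour (not part of the commit; it only assumes the blanket defined above):

import torch

from blanket import blanket

x = torch.randn(4, 8).requires_grad_()
y = blanket(x) * 100

# Each sample of blanket(x) has squared norm ~ numel / batch = 8,
# i.e. unit RMS per component, before the * 100
print((y / 100).flatten(1).pow(2).sum(dim=1))  # ~ [8., 8., 8., 8.]

# The incoming gradient (here 100 * ones) is renormalized per sample
# in backward, so the * 100 does not inflate it
g = torch.autograd.grad(y.sum(), x)[0]
print(g.flatten(1).pow(2).sum(dim=1))  # ~ [8., 8., 8., 8.]
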
diff --git a/mygpt.py b/mygpt.py
index 040845e..9a02bcd 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -21,6 +21,8 @@ from torch.nn import functional as F
 
 import ffutils
 
+from blanket import blanket
+
 # import memload
 
 ######################################################################
@@ -506,11 +508,11 @@ class Caterpillar(nn.Module):
 
         ######################################################################
 
-        self.w_G = randw(nb_heads, caterpillar_height, dim_model, factor=1.0)
+        self.w_G = randw(nb_heads, caterpillar_height, dim_model, factor=1e-3)
         self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), 0.0))
 
         self.w_K = randw(nb_heads, dim_qk, dim_model)
-        self.w_V = randw(nb_heads, dim_v, dim_model, factor=1)
+        self.w_V = randw(nb_heads, dim_v, dim_model)
         self.w_Q = randw(nb_heads, dim_qk, dim_model)
         self.w_O = randw(dim_v * nb_heads, dim_model)
 
@@ -567,6 +569,8 @@ class Caterpillar(nn.Module):
         V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
         K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
 
+        V, K = blanket(V), blanket(K)
+
         ######################################################################
         # Compute the recurrent state
 
@@ -583,19 +587,19 @@ class Caterpillar(nn.Module):
         # Clip the gating to avoid values greater than 1 when several
         # heads hit the same row
 
-        G = G / G.sum(1, keepdim=True).clamp(min=1)
+        G = G / G.sum(1, keepdim=True).clamp(min=1)
 
-        H = (1 - G).log().sum(1, keepdim=True).exp()
+        # G_star = (1 - G).log().sum(1, keepdim=True).exp()
 
         ######################################################################
 
         def recurrence(G, V, K):
             # We prepare the arguments for the parallel scan
 
-            A = H
+            A = 1 - G.sum(dim=1)
 
-            gated_V = torch.einsum("nhrt,nhtd->nrtd", H * G / (1 - G), V)
-            gated_K = torch.einsum("nhrt,nhtd->nrtd", H * G / (1 - G), K)
+            gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
+            gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)
 
             # We start from cached values, which matters in inference
 
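
Because the gates are clipped so that their sum over heads is at most 1, the new forget factor A = 1 - sum_h G_h stays in [0, 1], and each row of the recurrent state is updated as S_t = A_t * S_{t-1} + sum_h G_{h,t} V_{h,t}, which is what the parallel scan computes. A sequential reference for that recurrence, with made-up toy shapes:

import torch

# Made-up shapes: N samples, H heads, R caterpillar rows, T steps, D dims
N, H, R, T, D = 2, 3, 4, 5, 6
G = torch.rand(N, H, R, T)
V = torch.randn(N, H, T, D)

# Clip the gating so that its sum over heads never exceeds 1
G = G / G.sum(1, keepdim=True).clamp(min=1)

A = 1 - G.sum(dim=1)                             # NxRxT forget factor in [0, 1]
gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)  # NxRxTxD additive input

# S_t = A_t * S_{t-1} + gated_V_t, row by row
S = torch.zeros(N, R, D)
for t in range(T):
    S = A[:, :, t, None] * S + gated_V[:, :, t]
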
@@ -669,6 +673,8 @@ class Caterpillar(nn.Module):
 
         Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
 
+        Q = blanket(Q)
+
         # We build tensors NxHxTxRxL where N is the sample index, H
         # the head, T the time, R the row in the caterpillar, and L
         # the column in the caterpillar
@@ -706,6 +712,8 @@ class Caterpillar(nn.Module):
 
         # Compute the final output
 
+        Y = blanket(Y)
+
         self.cache_Y[:, t0:t1] = Y @ self.w_O
 
         return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
@@ -1081,31 +1089,35 @@ if __name__ == "__main__":
     # t = np.arange(dt, 20.0, dt)
     # ax.semilogx(t, np.exp(-t / 5.0))
     # ax.grid()
+    ax.set_yscale("log")
 
     ######################################################################
 
-    for label, model in [
-        # ("nn.Linear", linear),
-        ("mygpy.QKVAttention", qkv),
-        ("mygpt.Caterpillar", caterpillar),
+    for label, model, thickness in [
+        ("nn.Linear", linear, 0.2),
+        ("mygpy.QKVAttention", qkv, 1),
+        ("mygpt.Caterpillar", caterpillar, 2),
     ]:
         y = model(BracketedSequence(x, 32, x.size(1) - 32, init_cache=True)).x
 
-        data = []
-        for t in range(y.size(1)):
-            for d in torch.randperm(y.size(2))[:8]:
-                g = torch.autograd.grad(y[0, t, d], x, retain_graph=True)[0]
-                sg = g.pow(2).sum().item()
-                # sg = 0
-                # for p in model.parameters():
-                # g = torch.autograd.grad(y[0, t, d], p, retain_graph=True)[0]
-                # sg = sg + g.pow(2).sum().item()
-                data.append([t, sg])
-
-        data = torch.tensor(data)
-        ax.scatter(
-            data[:, 0], data[:, 1], s=1, label=label
-        )  # , color='gray', label='Input')
+        for n, p in [("input", x)] + list(model.named_parameters()):
+            print(f"Processing {model}.{n}")
+            data = []
+            for t in range(y.size(1)):
+                sg = 0
+                for d in torch.randperm(y.size(2))[:8]:
+                    sg += torch.autograd.grad(y[0, t, d], p, retain_graph=True)[0]
+                assert not sg.isinf().any()
+                assert not sg.isnan().any()
+                data.append([t, sg.sum().item()])
+
+            data = torch.tensor(data)
+            # cx, cy = data[:, 0], data[:, 1]
+            cy = data[:, 1].sort().values
+            cx = torch.linspace(0, 1, cy.size(0))
+            ax.plot(
+                cx, cy, label=label + "." + n, linewidth=thickness
+            )  # , color='gray', label='Input')
 
     # ax.legend(frameon=False, loc="top right")
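
The rewritten diagnostic no longer scatters per-coordinate gradient magnitudes against time: for the input and each parameter it sums the gradients of eight random output coordinates at every time step, sorts the resulting values, and plots them against a uniform [0, 1] abscissa on the log scale, giving one empirical distribution curve per tensor. The plotting idiom in isolation, with made-up values:

import torch
import matplotlib.pyplot as plt

values = torch.randn(100).exp()        # made-up positive magnitudes

cy = values.sort().values              # sorted magnitudes on the y axis
cx = torch.linspace(0, 1, cy.size(0))  # quantile position on the x axis

fig, ax = plt.subplots()
ax.set_yscale("log")
ax.plot(cx, cy, label="example")
ax.legend()
fig.savefig("example.png")
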