Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 6 Jan 2024 18:07:35 +0000 (19:07 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 6 Jan 2024 18:07:35 +0000 (19:07 +0100)
mygpt.py
pscan.py

index c061eb4..4d48247 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -563,7 +563,7 @@ class Caterpillar(nn.Module):
         # by updating that at time t-L, the parallel scan operates
         # with a period of L. To do so we split the time indexing in
         # two axes, the second of size CL, and run the parallel scan
-        # using the other alone as the sequence index.
+        # using the other as the sequence index.
 
         A = A.unflatten(2, (-1, CL))
         gated_V = gated_V.unflatten(2, (-1, CL))
index 0ec7b13..88cb3d5 100755 (executable)
--- a/pscan.py
+++ b/pscan.py
@@ -11,8 +11,8 @@ import torch
 
 
 class PScan(torch.autograd.Function):
-    # Given A is NxTx1 and X is NxTxD, expands A and X in place in O(T),
-    # and O(log(T)) if not core-bounded, so that
+    # Given A is NxTxMx1 and X is NxTxMxD, expands A and X in
+    # place in O(T), and O(log(T)) if not core-bounded, so that
     #
     # Y[:, 0] = Y_init
     # Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t]
@@ -23,33 +23,57 @@ class PScan(torch.autograd.Function):
 
     @staticmethod
     def expand_(A, X):
-        if A.size(1) == 1:
-            return
-        T = 2 * (A.size(1) // 2)
-        Aa = A[:, :T].view(A.size(0), T // 2, 2, -1, 1)
-        Xa = X[:, :T].view(X.size(0), T // 2, 2, -1, X.size(-1))
-        Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
-        Aa[:, :, 1].mul_(Aa[:, :, 0])
-        PScan.expand_(Aa[:, :, 1], Xa[:, :, 1])
-        Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1]))
-        Aa[:, 1:, 0].mul_(Aa[:, :-1, 1])
-        if T < A.size(1):
-            X[:, -1].add_(A[:, -1].mul(X[:, -2]))
-            A[:, -1].mul_(A[:, -2])
+        # Unrolling gains ~8% speed
+
+        if A.size(1) > 4:
+            T = 2 * (A.size(1) // 2)
+            Aa = A[:, :T].view(A.size(0), T // 2, 2, -1, 1)
+            Xa = X[:, :T].view(X.size(0), T // 2, 2, -1, X.size(-1))
+            Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
+            Aa[:, :, 1].mul_(Aa[:, :, 0])
+            PScan.expand_(Aa[:, :, 1], Xa[:, :, 1])
+            Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1]))
+            Aa[:, 1:, 0].mul_(Aa[:, :-1, 1])
+            if T < A.size(1):
+                X[:, -1].add_(A[:, -1].mul(X[:, -2]))
+                A[:, -1].mul_(A[:, -2])
+        elif A.size(1) == 2:
+            X[:, 1].add_(A[:, 1].mul(X[:, 0]))
+            A[:, 1].mul_(A[:, 0])
+        elif A.size(1) == 3:
+            X[:, 1].add_(A[:, 1].mul(X[:, 0]))
+            A[:, 1].mul_(A[:, 0])
+            X[:, 2].add_(A[:, 2].mul(X[:, 1]))
+            A[:, 2].mul_(A[:, 1])
+        elif A.size(1) == 4:
+            X[:, 1].add_(A[:, 1].mul(X[:, 0]))
+            A[:, 1].mul_(A[:, 0])
+            X[:, 2].add_(A[:, 2].mul(X[:, 1]))
+            A[:, 2].mul_(A[:, 1])
+            X[:, 3].add_(A[:, 3].mul(X[:, 2]))
+            A[:, 3].mul_(A[:, 2])
 
     @staticmethod
     def acc_rev_(A, X):
-        if X.size(1) == 1:
-            return
-        T = 2 * (X.size(1) // 2)
-        Aa = A[:, -T:].view(A.size(0), T // 2, 2, -1, 1)
-        Xa = X[:, -T:].view(X.size(0), T // 2, 2, -1, X.size(-1))
-        Xa[:, :, 0].add_(Aa[:, :, 1].mul(Xa[:, :, 1]))
-        B = Aa[:, :, 0].clone()
-        B[:, 1:].mul_(Aa[:, :-1, 1])
-        PScan.acc_rev_(B, Xa[:, :, 0])
-        Xa[:, :-1, 1].add_(Aa[:, 1:, 0].mul(Xa[:, 1:, 0]))
-        if T < A.size(1):
+        if A.size(1) > 4:
+            T = 2 * (X.size(1) // 2)
+            Aa = A[:, -T:].view(A.size(0), T // 2, 2, -1, 1)
+            Xa = X[:, -T:].view(X.size(0), T // 2, 2, -1, X.size(-1))
+            Xa[:, :, 0].add_(Aa[:, :, 1].mul(Xa[:, :, 1]))
+            B = Aa[:, :, 0].clone()
+            B[:, 1:].mul_(Aa[:, :-1, 1])
+            PScan.acc_rev_(B, Xa[:, :, 0])
+            Xa[:, :-1, 1].add_(Aa[:, 1:, 0].mul(Xa[:, 1:, 0]))
+            if T < A.size(1):
+                X[:, 0].add_(A[:, 1].mul(X[:, 1]))
+        elif A.size(1) == 2:
+            X[:, 0].add_(A[:, 1].mul(X[:, 1]))
+        elif A.size(1) == 3:
+            X[:, 1].add_(A[:, 2].mul(X[:, 2]))
+            X[:, 0].add_(A[:, 1].mul(X[:, 1]))
+        elif A.size(1) == 4:
+            X[:, 2].add_(A[:, 3].mul(X[:, 3]))
+            X[:, 1].add_(A[:, 2].mul(X[:, 2]))
             X[:, 0].add_(A[:, 1].mul(X[:, 1]))
 
     # A is NxT, X is NxTxD, Y_init is NxD
@@ -81,59 +105,80 @@ class PScan(torch.autograd.Function):
 
 pscan = PScan.apply
 
+
+def naive_pscan(A, X, Y_init):
+    y = Y_init
+    s = 0
+
+    for k in range(A.size(1)):
+        y = A[:, k, None] * y + X[:, k]
+        s = s + y
+
+    s = s.sum()
+
+
 ######################################################################
 
 if __name__ == "__main__":
     import time, sys
 
-    A = torch.rand(17, 12, 3)
-    X = torch.rand(17, 12, 3, 11)
-    Y_init = torch.rand(17, 3, 11)
-    Y = pscan(A, X, Y_init)
-    exit(0)
+    # A = torch.rand(17, 12, 3)
+    # X = torch.rand(17, 12, 3, 11)
+    # Y_init = torch.rand(17, 3, 11)
+    # Y = pscan(A, X, Y_init)
 
-    N, T, D = 2, 1047, 3
+    # exit(0)
 
-    A = torch.rand(N, T, dtype=torch.float64).requires_grad_()
-    X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
-    Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_()
+    err = 0
 
-    # Iterative implementation
+    for _ in range(100):
+        N, T, D = 2, 112, 3
 
-    y = Y_init
-    s = 0
+        T = torch.randint(10, (1,)).item() + 1
 
-    for k in range(A.size(1)):
-        y = A[:, k, None] * y + X[:, k]
-        s = s + y
+        A = 0.9 + 0.1 * torch.rand(N, T, dtype=torch.float64).requires_grad_()
+        X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
+        Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_()
 
-    s = s.sum()
+        # Iterative implementation
+
+        y = Y_init
+        s = 0
+
+        for k in range(A.size(1)):
+            y = A[:, k, None] * y + X[:, k]
+            s = s + y
+
+        s = s.sum()
 
-    gA_ref, gX_ref, gY_init_ref = torch.autograd.grad(
-        s, (A, X, Y_init), retain_graph=True
-    )
+        gA_ref, gX_ref, gY_init_ref = torch.autograd.grad(
+            s, (A, X, Y_init), retain_graph=True
+        )
 
-    # parallel scan
+        # parallel scan
 
-    start_time = time.perf_counter()
-    for _ in range(1000):
-        Y = pscan(A, X, Y_init)
-    duration = time.perf_counter() - start_time
-    print(f"duration {duration}")
+        start_time = time.perf_counter()
+        for _ in range(1000):
+            Y = pscan(A, X, Y_init)
+        duration = time.perf_counter() - start_time
+        print(f"duration {duration}")
 
-    s = Y.sum()
+        s = Y.sum()
 
-    gA, gX, gY_init = torch.autograd.grad(s, (A, X, Y_init), retain_graph=True)
+        gA, gX, gY_init = torch.autograd.grad(s, (A, X, Y_init), retain_graph=True)
 
-    # print(gA)
-    # print(gX)
-    # print(gY_init)
+        err = max(
+            [
+                err,
+                (gA - gA_ref).abs().max(),
+                (gX - gX_ref).abs().max(),
+                (gY_init - gY_init_ref).abs().max(),
+            ]
+        )
 
-    print((gA - gA_ref).norm())
-    print((gX - gX_ref).norm())
-    print((gY_init - gY_init_ref).norm())
+        # Y1 = pscan(A[:, : T // 2], X[:, : T // 2], Y_init)
+        # Y2 = pscan(A[:, T // 2 :], X[:, T // 2 :], Y1[:, -1])
 
-    Y1 = pscan(A[:, : T // 2], X[:, : T // 2], Y_init)
-    Y2 = pscan(A[:, T // 2 :], X[:, T // 2 :], Y1[:, -1])
+        # print((Y - torch.cat([Y1, Y2], dim=1)).abs().max())
 
-    print((Y - torch.cat([Y1, Y2], dim=1)).norm())
+print(f"{err=}")