Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 18 Dec 2023 00:47:34 +0000 (01:47 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 18 Dec 2023 00:47:34 +0000 (01:47 +0100)
pscan.py

index f344200..7b6cfc0 100755 (executable)
--- a/pscan.py
+++ b/pscan.py
@@ -4,63 +4,62 @@ import torch
 
 ######################################################################
 
-# Given A is NxTx1 and X is NxTxD, expands A and X in place in O(T),
-# and O(log(T)) if not core-bounded, so that
-#
-# Y[:, 0] = Y0
-# Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t]
-#
-# can be computed as
-#
-# Y[:, t] = A[:, t] * Y0 + X[:, t]
-
-
-def expand(A, X):
-    if A.size(1) == 1:
-        return
-    T = 2 * (A.size(1) // 2)
-    Aa = A[:, :T].view(A.size(0), T // 2, 2, -1)
-    Xa = X[:, :T].view(X.size(0), T // 2, 2, -1)
-    Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
-    Aa[:, :, 1].mul_(Aa[:, :, 0])
-    expand(Aa[:, :, 1], Xa[:, :, 1])
-    Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1]))
-    Aa[:, 1:, 0].mul_(Aa[:, :-1, 1])
-    if T < A.size(1):
-        X[:, -1].add_(A[:, -1].mul(X[:, -2]))
-        A[:, -1].mul_(A[:, -2])
-
-
-# Computes inplace Y[:, s] = \sum_{t >= s} X[:, t]
-
-
-def accrev(X):
-    if X.size(1) == 1:
-        return
-    T = 2 * (X.size(1) // 2)
-    Xa = X[:, -T:].view(X.size(0), T // 2, 2, -1)
-    Xa[:, :, 0].add_(Xa[:, :, 1])
-    accrev(Xa[:, :, 0])
-    Xa[:, :-1, 1].add_(Xa[:, 1:, 0])
-    if T < X.size(1):
-        X[:, 0].add_(X[:, 1])
-
 
 class PScan(torch.autograd.Function):
+    # Given A is NxTx1 and X is NxTxD, expands A and X in place in O(T),
+    # and O(log(T)) if not core-bounded, so that
+    #
+    # Y[:, 0] = Y0
+    # Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t]
+    #
+    # can be computed as
+    #
+    # Y[:, t] = A[:, t] * Y0 + X[:, t]
+
+    @staticmethod
+    def expand(A, X):
+        if A.size(1) == 1:
+            return
+        T = 2 * (A.size(1) // 2)
+        Aa = A[:, :T].view(A.size(0), T // 2, 2, -1)
+        Xa = X[:, :T].view(X.size(0), T // 2, 2, -1)
+        Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
+        Aa[:, :, 1].mul_(Aa[:, :, 0])
+        PScan.expand(Aa[:, :, 1], Xa[:, :, 1])
+        Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1]))
+        Aa[:, 1:, 0].mul_(Aa[:, :-1, 1])
+        if T < A.size(1):
+            X[:, -1].add_(A[:, -1].mul(X[:, -2]))
+            A[:, -1].mul_(A[:, -2])
+
+    # Computes inplace Y[:, s] = \sum_{t >= s} X[:, t]
+
+    @staticmethod
+    def accrev(X):
+        if X.size(1) == 1:
+            return
+        T = 2 * (X.size(1) // 2)
+        Xa = X[:, -T:].view(X.size(0), T // 2, 2, -1)
+        Xa[:, :, 0].add_(Xa[:, :, 1])
+        PScan.accrev(Xa[:, :, 0])
+        Xa[:, :-1, 1].add_(Xa[:, 1:, 0])
+        if T < X.size(1):
+            X[:, 0].add_(X[:, 1])
+
     @staticmethod
     def forward(ctx, A, X, Y0):
         ctx.A = A[:, :, None].clone()
         ctx.Y0 = Y0[:, None, :].clone()
         ctx.A_star = A[:, :, None].clone()
         ctx.X_star = X.clone()
-        expand(ctx.A_star, ctx.X_star)
+        PScan.expand(ctx.A_star, ctx.X_star)
         return ctx.A_star * ctx.Y0 + ctx.X_star
 
     @staticmethod
     def backward(ctx, grad_output):
         U = grad_output * ctx.A_star
         R = U.clone()
-        accrev(R)
+        PScan.accrev(R)
         Q = ctx.Y0 / ctx.A
         Q[:, 1:].add_(ctx.X_star[:, :-1] / ctx.A_star[:, 1:])
         return (Q * R).sum(-1), R / ctx.A_star, U
@@ -71,6 +70,8 @@ pscan = PScan.apply
 ######################################################################
 
 if __name__ == "__main__":
+    # Iterative implementation
+
     A = torch.randn(1, 5, dtype=torch.float64).requires_grad_()
     X = torch.randn(1, 5, 3, dtype=torch.float64).requires_grad_()
     Y0 = torch.randn(1, 3, dtype=torch.float64).requires_grad_()
@@ -85,10 +86,12 @@ if __name__ == "__main__":
     print(torch.autograd.grad(y.mean(), X, retain_graph=True))
     print(torch.autograd.grad(y.mean(), Y0, retain_graph=True))
 
-    Y = pscan(A, X, Y0)
-
     print()
 
+    # parallel scan
+
+    Y = pscan(A, X, Y0)
+
     for k in range(A.size(1)):
         print(f"{k} -> {Y[:,k]}")