Update: compare the gradients of the parallel scan against the iterative reference with norms instead of printing them.
author François Fleuret <francois@fleuret.org>
Mon, 18 Dec 2023 01:57:06 +0000 (02:57 +0100)
committer François Fleuret <francois@fleuret.org>
Mon, 18 Dec 2023 01:57:06 +0000 (02:57 +0100)
pscan.py

index 1dfb442..071f284 100755 (executable)
--- a/pscan.py
+++ b/pscan.py
@@ -77,39 +77,31 @@ pscan = PScan.apply
 if __name__ == "__main__":
     N, T, D = 2, 5, 3
 
-    # Iterative implementation
-
     A = torch.randn(N, T, dtype=torch.float64).requires_grad_()
     X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
     Y0 = torch.randn(N, D, dtype=torch.float64).requires_grad_()
 
+    # Iterative implementation
+
     y = Y0
     s = 0
 
     for k in range(A.size(1)):
         y = A[:, k, None] * y + X[:, k]
         s = s + y
-        # print(f"{k} -> {y}")
 
     s = s.sum()
 
-    # print(s)
-    print(torch.autograd.grad(s, A, retain_graph=True))
-    print(torch.autograd.grad(s, X, retain_graph=True))
-    print(torch.autograd.grad(s, Y0, retain_graph=True))
-
-    print()
+    gA_ref, gX_ref, gY0_ref = torch.autograd.grad(s, (A, X, Y0), retain_graph=True)
 
     # parallel scan
 
     Y = pscan(A, X, Y0)
 
-    # for k in range(A.size(1)):
-    # print(f"{k} -> {Y[:,k]}")
-
     s = Y.sum()
 
-    # print(s)
-    print(torch.autograd.grad(s, A, retain_graph=True))
-    print(torch.autograd.grad(s, X, retain_graph=True))
-    print(torch.autograd.grad(s, Y0, retain_graph=True))
+    gA, gX, gY0 = torch.autograd.grad(s, (A, X, Y0), retain_graph=True)
+
+    print((gA - gA_ref).norm())
+    print((gX - gX_ref).norm())
+    print((gY0 - gY0_ref).norm())
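
Side note on what the new check verifies: the two sums can only have matching
gradients if pscan(A, X, Y0) returns the full sequence of recurrence states,
i.e. Y[:, k] = y_k with y_k = A[:, k] * y_{k-1} + X[:, k] and y_{-1} = Y0.
Below is a minimal sequential sketch of that reference, assuming exactly this
output convention; the helper name pscan_ref is hypothetical, not taken from
the repository:

    import torch

    def pscan_ref(A, X, Y0):
        # Sequential reference for the recurrence exercised in the test:
        #   y_k = A[:, k] * y_{k-1} + X[:, k],  with y_{-1} = Y0
        # Shapes: A (N, T), X (N, T, D), Y0 (N, D) -> Y (N, T, D), Y[:, k] = y_k
        y, ys = Y0, []
        for k in range(A.size(1)):
            y = A[:, k, None] * y + X[:, k]
            ys.append(y)
        return torch.stack(ys, dim=1)

    # Under that assumption the forward outputs should match as well:
    # print((pscan_ref(A, X, Y0) - pscan(A, X, Y0)).norm())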