Update.
[mygptrnn.git] / pscan.py
index 88cb3d5..0bb0d14 100755 (executable)
--- a/pscan.py
+++ b/pscan.py
@@ -122,6 +122,22 @@ def naive_pscan(A, X, Y_init):
 if __name__ == "__main__":
     import time, sys
 
+    ######################################################################
+
+    N, T, D = 16, 4096, 32
+
+    for r in range(timing.size(0)):
+        A = 0.9 + 0.1 * torch.rand(N, T, dtype=torch.float64).requires_grad_()
+        X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
+        Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_()
+
+        start_time = time.perf_counter()
+        for _ in range(1000):
+            Y = pscan(A, X, Y_init)
+        duration = time.perf_counter() - start_time
+
+    ######################################################################
+
     # A = torch.rand(17, 12, 3)
     # X = torch.rand(17, 12, 3, 11)
     # Y_init = torch.rand(17, 3, 11)
@@ -130,11 +146,12 @@ if __name__ == "__main__":
     # exit(0)
 
     err = 0
+    timing = torch.empty(10)
 
-    for _ in range(100):
-        N, T, D = 2, 112, 3
+    for r in range(timing.size(0)):
+        N, T, D = 2, 1120, 3
 
-        T = torch.randint(10, (1,)).item() + 1
+        T = torch.randint(10, (1,)).item() + 1
 
         A = 0.9 + 0.1 * torch.rand(N, T, dtype=torch.float64).requires_grad_()
         X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
@@ -161,7 +178,9 @@ if __name__ == "__main__":
         for _ in range(1000):
             Y = pscan(A, X, Y_init)
         duration = time.perf_counter() - start_time
+
         print(f"duration {duration}")
+        timing[r] = duration
 
         s = Y.sum()
 
@@ -181,4 +200,4 @@ if __name__ == "__main__":
 
         # print((Y - torch.cat([Y1, Y2], dim=1)).abs().max())
 
-print(f"{err=}")
+    print(f"err={err:.2e} duration={timing.mean():.2e} (+/- {timing.std():.2e})")