if __name__ == "__main__":
import time, sys
+ ######################################################################
+
+ N, T, D = 16, 4096, 32
+
+ for r in range(timing.size(0)):
+ A = 0.9 + 0.1 * torch.rand(N, T, dtype=torch.float64).requires_grad_()
+ X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
+ Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_()
+
+ start_time = time.perf_counter()
+ for _ in range(1000):
+ Y = pscan(A, X, Y_init)
+ duration = time.perf_counter() - start_time
+
+ ######################################################################
+
# A = torch.rand(17, 12, 3)
# X = torch.rand(17, 12, 3, 11)
# Y_init = torch.rand(17, 3, 11)
# exit(0)
err = 0
+ timing = torch.empty(10)
- for _ in range(100):
- N, T, D = 2, 112, 3
+ for r in range(timing.size(0)):
+ N, T, D = 2, 1120, 3
- T = torch.randint(10, (1,)).item() + 1
+ # T = torch.randint(10, (1,)).item() + 1
A = 0.9 + 0.1 * torch.rand(N, T, dtype=torch.float64).requires_grad_()
X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
for _ in range(1000):
Y = pscan(A, X, Y_init)
duration = time.perf_counter() - start_time
+
print(f"duration {duration}")
+ timing[r] = duration
s = Y.sum()
# print((Y - torch.cat([Y1, Y2], dim=1)).abs().max())
-print(f"{err=}")
+ print(f"err={err:.2e} duration={timing.mean():.2e} (+/- {timing.std():.2e})")