X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=pscan.py;h=0bb0d145bf9c6c82115956c8ce1e6a063e56e747;hb=HEAD;hp=88cb3d555b77036e92783602a3f8d8a96ffea4b9;hpb=c0750e416e28fbdc9f6dc03cc6d7b11edd1ac333;p=mygptrnn.git diff --git a/pscan.py b/pscan.py index 88cb3d5..0bb0d14 100755 --- a/pscan.py +++ b/pscan.py @@ -122,6 +122,22 @@ def naive_pscan(A, X, Y_init): if __name__ == "__main__": import time, sys + ###################################################################### + + N, T, D = 16, 4096, 32 + + for r in range(timing.size(0)): + A = 0.9 + 0.1 * torch.rand(N, T, dtype=torch.float64).requires_grad_() + X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_() + Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_() + + start_time = time.perf_counter() + for _ in range(1000): + Y = pscan(A, X, Y_init) + duration = time.perf_counter() - start_time + + ###################################################################### + # A = torch.rand(17, 12, 3) # X = torch.rand(17, 12, 3, 11) # Y_init = torch.rand(17, 3, 11) @@ -130,11 +146,12 @@ if __name__ == "__main__": # exit(0) err = 0 + timing = torch.empty(10) - for _ in range(100): - N, T, D = 2, 112, 3 + for r in range(timing.size(0)): + N, T, D = 2, 1120, 3 - T = torch.randint(10, (1,)).item() + 1 + # T = torch.randint(10, (1,)).item() + 1 A = 0.9 + 0.1 * torch.rand(N, T, dtype=torch.float64).requires_grad_() X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_() @@ -161,7 +178,9 @@ if __name__ == "__main__": for _ in range(1000): Y = pscan(A, X, Y_init) duration = time.perf_counter() - start_time + print(f"duration {duration}") + timing[r] = duration s = Y.sum() @@ -181,4 +200,4 @@ if __name__ == "__main__": # print((Y - torch.cat([Y1, Y2], dim=1)).abs().max()) -print(f"{err=}") + print(f"err={err:.2e} duration={timing.mean():.2e} (+/- {timing.std():.2e})")