From 59513fa775776af525477f01925f563ddffecdb2 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 18 Dec 2023 02:57:06 +0100 Subject: [PATCH] Update. --- pscan.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/pscan.py b/pscan.py index 1dfb442..071f284 100755 --- a/pscan.py +++ b/pscan.py @@ -77,39 +77,31 @@ pscan = PScan.apply if __name__ == "__main__": N, T, D = 2, 5, 3 - # Iterative implementation - A = torch.randn(N, T, dtype=torch.float64).requires_grad_() X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_() Y0 = torch.randn(N, D, dtype=torch.float64).requires_grad_() + # Iterative implementation + y = Y0 s = 0 for k in range(A.size(1)): y = A[:, k, None] * y + X[:, k] s = s + y - # print(f"{k} -> {y}") s = s.sum() - # print(s) - print(torch.autograd.grad(s, A, retain_graph=True)) - print(torch.autograd.grad(s, X, retain_graph=True)) - print(torch.autograd.grad(s, Y0, retain_graph=True)) - - print() + gA_ref, gX_ref, gY0_ref = torch.autograd.grad(s, (A, X, Y0), retain_graph=True) # parallel scan Y = pscan(A, X, Y0) - # for k in range(A.size(1)): - # print(f"{k} -> {Y[:,k]}") - s = Y.sum() - # print(s) - print(torch.autograd.grad(s, A, retain_graph=True)) - print(torch.autograd.grad(s, X, retain_graph=True)) - print(torch.autograd.grad(s, Y0, retain_graph=True)) + gA, gX, gY0 = torch.autograd.grad(s, (A, X, Y0), retain_graph=True) + + print((gA - gA_ref).norm()) + print((gX - gX_ref).norm()) + print((gY0 - gY0_ref).norm()) -- 2.20.1