From: François Fleuret Date: Mon, 18 Dec 2023 01:57:06 +0000 (+0100) Subject: Update. X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=commitdiff_plain;h=59513fa775776af525477f01925f563ddffecdb2 Update. --- diff --git a/pscan.py b/pscan.py index 1dfb442..071f284 100755 --- a/pscan.py +++ b/pscan.py @@ -77,39 +77,31 @@ pscan = PScan.apply if __name__ == "__main__": N, T, D = 2, 5, 3 - # Iterative implementation - A = torch.randn(N, T, dtype=torch.float64).requires_grad_() X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_() Y0 = torch.randn(N, D, dtype=torch.float64).requires_grad_() + # Iterative implementation + y = Y0 s = 0 for k in range(A.size(1)): y = A[:, k, None] * y + X[:, k] s = s + y - # print(f"{k} -> {y}") s = s.sum() - # print(s) - print(torch.autograd.grad(s, A, retain_graph=True)) - print(torch.autograd.grad(s, X, retain_graph=True)) - print(torch.autograd.grad(s, Y0, retain_graph=True)) - - print() + gA_ref, gX_ref, gY0_ref = torch.autograd.grad(s, (A, X, Y0), retain_graph=True) # parallel scan Y = pscan(A, X, Y0) - # for k in range(A.size(1)): - # print(f"{k} -> {Y[:,k]}") - s = Y.sum() - # print(s) - print(torch.autograd.grad(s, A, retain_graph=True)) - print(torch.autograd.grad(s, X, retain_graph=True)) - print(torch.autograd.grad(s, Y0, retain_graph=True)) + gA, gX, gY0 = torch.autograd.grad(s, (A, X, Y0), retain_graph=True) + + print((gA - gA_ref).norm()) + print((gX - gX_ref).norm()) + print((gY0 - gY0_ref).norm())