From 674dd7c7adde6b4a9aaa5afd57dbe1d063a47fcc Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 18 Dec 2023 04:52:50 +0100 Subject: [PATCH] Update. --- pscan.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/pscan.py b/pscan.py index 071f284..3526c31 100755 --- a/pscan.py +++ b/pscan.py @@ -14,12 +14,12 @@ class PScan(torch.autograd.Function): # Given A is NxTx1 and X is NxTxD, expands A and X in place in O(T), # and O(log(T)) if not core-bounded, so that # - # Y[:, 0] = Y0 + # Y[:, 0] = Y_init # Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t] # # can be computed as # - # Y[:, t] = A[:, t] * Y0 + X[:, t] + # Y[:, t] = A[:, t] * Y_init + X[:, t] @staticmethod def expand(A, X): @@ -51,21 +51,28 @@ class PScan(torch.autograd.Function): if T < X.size(1): X[:, 0].add_(X[:, 1]) + # A is NxT, X is NxTxD, Y_init is NxD + # + # returns Y of same shape as X, with + # + # Y[:,t] = A[:,0] * Y_init + X[:,0] if t == 0 + # = A[:,t] * Y[:,t-1] + X[:,t] otherwise + @staticmethod - def forward(ctx, A, X, Y0): + def forward(ctx, A, X, Y_init): ctx.A = A[:, :, None].clone() - ctx.Y0 = Y0[:, None, :].clone() + ctx.Y_init = Y_init[:, None, :].clone() ctx.A_star = A[:, :, None].clone() ctx.X_star = X.clone() PScan.expand(ctx.A_star, ctx.X_star) - return ctx.A_star * ctx.Y0 + ctx.X_star + return ctx.A_star * ctx.Y_init + ctx.X_star @staticmethod def backward(ctx, grad_output): U = grad_output * ctx.A_star R = U.clone() PScan.accrev(R) - Q = ctx.Y0 / ctx.A + Q = ctx.Y_init / ctx.A Q[:, 1:].add_(ctx.X_star[:, :-1] / ctx.A_star[:, 1:]) return (Q * R).sum(-1), R / ctx.A_star, U.sum(dim=1) @@ -79,11 +86,11 @@ if __name__ == "__main__": A = torch.randn(N, T, dtype=torch.float64).requires_grad_() X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_() - Y0 = torch.randn(N, D, dtype=torch.float64).requires_grad_() + Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_() # Iterative implementation - y = Y0 + y = Y_init s = 0 for k in range(A.size(1)): @@ -92,16 +99,18 @@ if __name__ == "__main__": s = s.sum() - gA_ref, gX_ref, gY0_ref = torch.autograd.grad(s, (A, X, Y0), retain_graph=True) + gA_ref, gX_ref, gY_init_ref = torch.autograd.grad( + s, (A, X, Y_init), retain_graph=True + ) # parallel scan - Y = pscan(A, X, Y0) + Y = pscan(A, X, Y_init) s = Y.sum() - gA, gX, gY0 = torch.autograd.grad(s, (A, X, Y0), retain_graph=True) + gA, gX, gY_init = torch.autograd.grad(s, (A, X, Y_init), retain_graph=True) print((gA - gA_ref).norm()) print((gX - gX_ref).norm()) - print((gY0 - gY0_ref).norm()) + print((gY_init - gY_init_ref).norm()) -- 2.20.1