From: François Fleuret Date: Sat, 16 Dec 2023 13:54:40 +0000 (-0600) Subject: Update. X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=commitdiff_plain;h=8c12767fe586074920e3d4abb05e4393a145351a Update. --- diff --git a/pscan.py b/pscan.py new file mode 100755 index 0000000..36490ff --- /dev/null +++ b/pscan.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python + +import math + +import torch, torchvision + +from torch import nn +from torch.nn import functional as F + +###################################################################### + + +def naive_rec(A, X, Y0): + Y = [] + for t in range(X.size(1)): + if t == 0: + Y.append(A[:, t] * Y0 + X[:, t]) + else: + Y.append(A[:, t] * Y[-1] + X[:, t]) + + return torch.cat([y[:, None, :] for y in Y], dim=1) + + +###################################################################### + +# A is NxTx1 and X is NxTxD +# +# Returns Y defined with +# +# Y[:, 0] = A[:, 0] * Y0 + X[:,0] +# for t > 0 Y[:, t] = A[:, t] * Y[:, t - 1] + X[:, t] + + +def pscan_rec(A, X, Y0): + if X.size(1) % 2 == 1: + if X.size(1) == 1: + return A[:, :1] * Y0[:, None] + X[:, :1] + else: + Y = pscan_rec(A[:, :-1], X[:, :-1], Y0) + return torch.cat([Y, A[:, -1:] * Y[:, -1:] + X[:, -1:]], dim=1) + + A2 = A.reshape(A.size(0), A.size(1) // 2, 2, A.size(2)) + X2 = X.reshape(X.size(0), X.size(1) // 2, 2, X.size(2)) + + X_star = X2[:, :, 0].clone() + X_star[:, 1:] += A2[:, 1:, 0] * X2[:, :-1, 1] + + A_star = A2[:, :, 0].clone() + A_star[:, 1:] *= A2[:, :-1, 1] + + Y_star = pscan_rec(A_star, X_star, Y0)[:, :, None] + + Y = torch.cat([Y_star, A2[:, :, 1, None] * Y_star + X2[:, :, 1, None]], dim=2) + + Y = Y.reshape(Y.size(0), -1, Y.size(-1)) + + return Y + + +###################################################################### + +N, T, D = 5, 29, 12 + +A = torch.rand(N, T, 1, dtype=torch.float64) +X = torch.randint(10, (N, T, D), dtype=torch.float64) +Y0 = torch.randint(10, (N, D), dtype=torch.float64) + +naive_Y = naive_rec(A, X, Y0) + +pscan_Y = pscan_rec(A, X, Y0) + +print((naive_Y - pscan_Y).pow(2).mean()) + +pscan_Y1 = pscan_rec(A[:, :15], X[:, :15], Y0) +pscan_Y2 = pscan_rec(A[:, 15:], X[:, 15:], pscan_Y1[:, -1]) + +print((naive_Y - torch.cat([pscan_Y1, pscan_Y2], dim=1)).pow(2).mean())