From 3da672d90cfade894c2dbe87cc9058c02e4d19ea Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 27 Jan 2024 18:07:04 +0100 Subject: [PATCH] Update. --- maxval.py | 61 +++++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 50 insertions(+), 11 deletions(-) diff --git a/maxval.py b/maxval.py index 747f4b9..a202b17 100755 --- a/maxval.py +++ b/maxval.py @@ -25,13 +25,14 @@ def baseline(X, V): def pscan(X, V, s=1): if X.size(1) == 1: - return X, V + return T = 2 * (X.size(1) // 2) Xf = X[:, :T].view(X.size(0), X.size(1) // 2, 2, X.size(2)) Vf = V[:, :T].view(V.size(0), V.size(1) // 2, 2) + # [:, :, 0] < [:, :, 1] m = (Vf[:, :, 0] - s >= Vf[:, :, 1]).long() Vf[:, :, 1] = m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1] m = m[:, :, None] @@ -39,12 +40,48 @@ def pscan(X, V, s=1): pscan(Xf[:, :, 1], Vf[:, :, 1], s * 2) - m = (Vf[:, 1:, 0] >= Vf[:, :-1, 1] - s).long() - Vf[:, 1:, 0] = m * Vf[:, 1:, 0] + (1 - m) * (Vf[:, :-1, 1] - s) + # [:, :-1, 1] < [:, 1:, 0] + m = (Vf[:, :-1, 1] - s >= Vf[:, 1:, 0]).long() + Vf[:, 1:, 0] = m * (Vf[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0] m = m[:, :, None] - Xf[:, 1:, 0] = m * Xf[:, 1:, 0] + (1 - m) * Xf[:, :-1, 1] + Xf[:, 1:, 0] = m * Xf[:, :-1, 1] + (1 - m) * Xf[:, 1:, 0] if T < X.size(1): + # [:, -2] < [:, -1] + m = (V[:, -2] - s >= V[:, -1]).long() + V[:, -1] = m * (V[:, -2] - s) + (1 - m) * V[:, -1] + m = m[:, None] + X[:, -1] = m * X[:, -2] + (1 - m) * X[:, -1] + + +###################################################################### + + +def pscan_diff(X, V, s=1): + if X.size(1) == 1: + return + + T = 2 * (X.size(1) // 2) + + Xf = X[:, :T].view(X.size(0), X.size(1) // 2, 2, X.size(2)) + Vf = V[:, :T].view(V.size(0), V.size(1) // 2, 2) + + # [:, :, 0] < [:, :, 1] + m = (Vf[:, :, 0] - s >= Vf[:, :, 1]).long() + Vf[:, :, 1] = m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1] + m = m[:, :, None] + Xf[:, :, 1] = m * Xf[:, :, 0] + (1 - m) * Xf[:, :, 1] + + pscan_diff(Xf[:, :, 1], Vf[:, :, 1], s * 2) + + # [:, :-1, 1] < [:, 1:, 0] + m = (Vf[:, :-1, 1] - s >= Vf[:, 1:, 0]).long() + Vf[:, 1:, 0] = m * (Vf[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0] + m = m[:, :, None] + Xf[:, 1:, 0] = m * Xf[:, :-1, 1] + (1 - m) * Xf[:, 1:, 0] + + if T < X.size(1): + # [:, -2] < [:, -1] m = (V[:, -2] - s >= V[:, -1]).long() V[:, -1] = m * (V[:, -2] - s) + (1 - m) * V[:, -1] m = m[:, None] @@ -58,23 +95,25 @@ if __name__ == "__main__": T = 513 D = 2 - # X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_() - # V = torch.rand(N, T, dtype=torch.float64) * 50 + X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_() + V = torch.rand(N, T, dtype=torch.float64) * 10 - # X0, V0 = baseline(X, V) + X0, V0 = baseline(X, V) # print("########### X0 V0 ###########################################") # print(V0) # print(X0) - # X1, V1 = X.clone(), V.clone() - # pscan(X1, V1) + X1, V1 = X.clone(), V.clone() + pscan_diff(X1, V1) # print("########### X V ############################################") # print(V) # print(X) - # print("ERROR", ((X0 - X1).abs().max() + (V0 - V1).abs().max()).item()) + print("ERROR", ((X0 - X1).abs().max() + (V0 - V1).abs().max()).item()) + + exit(0) # s = X1.sum() # print(torch.autograd.grad(s, X)) @@ -93,7 +132,7 @@ if __name__ == "__main__": for k in range(1000): X1, V1 = X.clone(), V.clone() - pscan(X1, V1) + pscan(X, V, X1, V1) # X1=X1*(1+V1-V1.detach())[:,:,None] loss = (X1[:, -1:] - Y).pow(2).mean() print(k, loss.item()) -- 2.20.1