Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 27 Jan 2024 17:50:30 +0000 (18:50 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 27 Jan 2024 17:50:30 +0000 (18:50 +0100)
maxval.py

index 31aa4b6..a245287 100755 (executable)
--- a/maxval.py
+++ b/maxval.py
@@ -72,10 +72,12 @@ def pscan_diff(X, V, s=1):
     Vrf = Vr[:, :T].view(Vr.size(0), Vr.size(1) // 2, 2)
 
     # [:, :, 0] < [:, :, 1]
+    dx = Xf[:, :, 1] - Xf[:, :, 1].detach()
+    dv = (Vf[:, :, 1] - Vf[:, :, 1].detach())[:, :, None]
     m = (Vf[:, :, 0] - s >= Vf[:, :, 1]).long()
     Vv = m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1]
     m = m[:, :, None]
-    Xx = m * Xf[:, :, 0] + (1 - m) * Xf[:, :, 1]
+    Xx = m * Xf[:, :, 0] + (1 - m) * (Xf[:, :, 1] * (1 + dv) + dx)
 
     Xrf[:, :, 1], Vrf[:, :, 1] = pscan_diff(Xx, Vv, s * 2)
 
@@ -83,17 +85,21 @@ def pscan_diff(X, V, s=1):
     Vr[:, 0] = V[:, 0]
 
     # [:, :-1, 1] < [:, 1:, 0]
+    dx = Xf[:, 1:, 0] - Xf[:, 1:, 0].detach()
+    dv = (Vf[:, 1:, 0] - Vf[:, 1:, 0].detach())[:, :, None]
     m = (Vrf[:, :-1, 1] - s >= Vf[:, 1:, 0]).long()
     Vrf[:, 1:, 0] = m * (Vrf[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0]
     m = m[:, :, None]
-    Xrf[:, 1:, 0] = m * Xrf[:, :-1, 1] + (1 - m) * Xf[:, 1:, 0]
+    Xrf[:, 1:, 0] = m * Xrf[:, :-1, 1] + (1 - m) * (Xf[:, 1:, 0] * (1 + dv) + dx)
 
     if T < X.size(1):
         # [:, -2] < [:, -1]
+        dx = X[:, -1] - X[:, -1].detach()
+        dv = (V[:, -1] - V[:, -1].detach())[:, None]
         m = (V[:, -2] - s >= V[:, -1]).long()
         Vr[:, -1] = m * (V[:, -2] - s) + (1 - m) * V[:, -1]
         m = m[:, None]
-        Xr[:, -1] = m * X[:, -2] + (1 - m) * X[:, -1]
+        Xr[:, -1] = m * X[:, -2] + (1 - m) * (X[:, -1] * (1 + dv) + dx)
 
     return Xr, Vr