+ # [:, -2] < [:, -1]
+ m = (V[:, -2] - s >= V[:, -1]).long()
+ V[:, -1] = m * (V[:, -2] - s) + (1 - m) * V[:, -1]
+ m = m[:, None]
+ X[:, -1] = m * X[:, -2] + (1 - m) * X[:, -1]
+
+
+######################################################################
+
+
+def pscan_diff(X, V, s=1):
+ if X.size(1) == 1:
+ return
+
+ T = 2 * (X.size(1) // 2)
+
+ Xf = X[:, :T].view(X.size(0), X.size(1) // 2, 2, X.size(2))
+ Vf = V[:, :T].view(V.size(0), V.size(1) // 2, 2)
+
+ # [:, :, 0] < [:, :, 1]
+ m = (Vf[:, :, 0] - s >= Vf[:, :, 1]).long()
+ Vf[:, :, 1] = m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1]
+ m = m[:, :, None]
+ Xf[:, :, 1] = m * Xf[:, :, 0] + (1 - m) * Xf[:, :, 1]
+
+ pscan_diff(Xf[:, :, 1], Vf[:, :, 1], s * 2)
+
+ # [:, :-1, 1] < [:, 1:, 0]
+ m = (Vf[:, :-1, 1] - s >= Vf[:, 1:, 0]).long()
+ Vf[:, 1:, 0] = m * (Vf[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0]
+ m = m[:, :, None]
+ Xf[:, 1:, 0] = m * Xf[:, :-1, 1] + (1 - m) * Xf[:, 1:, 0]
+
+ if T < X.size(1):
+ # [:, -2] < [:, -1]