- m = (V[:, t] >= W[:, t - 1] - 1).long()
- Y[:, t] = m * X[:, t] + (1 - m) * Y[:, t - 1]
- W[:, t] = m * V[:, t] + (1 - m) * (W[:, t - 1] - 1)
+ m = (W[:, t - 1] - 1 >= V[:, t]).long()
+ W[:, t] = m * (W[:, t - 1] - 1) + (1 - m) * V[:, t]
+ Y[:, t] = m * Y[:, t - 1] + (1 - m) * (
+ X[:, t] * (1 + dv) + Y[:, t - 1] * dv0
+ )
+
+ return Y, W
+
+
+######################################################################
+
+
+def hs(x):
+ return x.sigmoid() # (x >= 0).float() + (x - x.detach()) * (x < 0).float()
+
+
+def baseline(X, V):
+ for t in range(X.size(1)):
+ if t == 0:
+ Y = X[:, t]
+ W = V[:, t]
+ else:
+ m = (W - 1 - V[:, t]).sigmoid()
+ # m = hs(W - 1 - V[:, t])
+ W = m * (W - 1) + (1 - m) * V[:, t]
+ Y = m * Y + (1 - m) * X[:, t]