Update.
[mygptrnn.git] / maxval.py
1 #!/usr/bin/env python
2
3 import torch
4
5 ######################################################################
6
7
8 def baseline1(X, V):
9     Y = X.new(X.size())
10     W = V.new(V.size())
11
12     for t in range(X.size(1)):
13         if t == 0:
14             Y[:, t] = X[:, t]
15             W[:, t] = V[:, t]
16         else:
17             m = (W[:, t - 1] - 1 >= V[:, t]).long()
18             W[:, t] = m * (W[:, t - 1] - 1) + (1 - m) * V[:, t]
19             Y[:, t] = m * Y[:, t - 1] + (1 - m) * (
20                 X[:, t] * (1 + dv) + Y[:, t - 1] * dv0
21             )
22
23     return Y, W
24
25
26 ######################################################################
27
28
29 def hs(x):
30     return x.sigmoid()  # (x >= 0).float() + (x - x.detach()) * (x < 0).float()
31
32
33 def baseline(X, V):
34     for t in range(X.size(1)):
35         if t == 0:
36             Y = X[:, t]
37             W = V[:, t]
38         else:
39             m = (W - 1 - V[:, t]).sigmoid()
40             # m = hs(W - 1 - V[:, t])
41             W = m * (W - 1) + (1 - m) * V[:, t]
42             Y = m * Y + (1 - m) * X[:, t]
43
44     return Y, W
45
46
47 ######################################################################
48
49
50 def pscan(X, V, s=1):
51     if X.size(1) == 1:
52         return
53
54     T = 2 * (X.size(1) // 2)
55
56     Xf = X[:, :T].view(X.size(0), X.size(1) // 2, 2, X.size(2))
57     Vf = V[:, :T].view(V.size(0), V.size(1) // 2, 2)
58
59     # [:, :, 0] < [:, :, 1]
60     m = (Vf[:, :, 0] - s >= Vf[:, :, 1]).long()
61     Vf[:, :, 1] = m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1]
62     m = m[:, :, None]
63     Xf[:, :, 1] = m * Xf[:, :, 0] + (1 - m) * Xf[:, :, 1]
64
65     pscan(Xf[:, :, 1], Vf[:, :, 1], s * 2)
66
67     # [:, :-1, 1] < [:, 1:, 0]
68     m = (Vf[:, :-1, 1] - s >= Vf[:, 1:, 0]).long()
69     Vf[:, 1:, 0] = m * (Vf[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0]
70     m = m[:, :, None]
71     Xf[:, 1:, 0] = m * Xf[:, :-1, 1] + (1 - m) * Xf[:, 1:, 0]
72
73     if T < X.size(1):
74         # [:, -2] < [:, -1]
75         m = (V[:, -2] - s >= V[:, -1]).long()
76         V[:, -1] = m * (V[:, -2] - s) + (1 - m) * V[:, -1]
77         m = m[:, None]
78         X[:, -1] = m * X[:, -2] + (1 - m) * X[:, -1]
79
80
81 ######################################################################
82
83
84 def pscan_diff(X, V, s=1):
85     if X.size(1) == 1:
86         return X, V
87
88     T = 2 * (X.size(1) // 2)
89
90     Xf = X[:, :T].view(X.size(0), X.size(1) // 2, 2, X.size(2))
91     Vf = V[:, :T].view(V.size(0), V.size(1) // 2, 2)
92
93     Xr = X.new(X.size())
94     Vr = V.new(V.size())
95     Xrf = Xr[:, :T].view(Xr.size(0), Xr.size(1) // 2, 2, Xr.size(2))
96     Vrf = Vr[:, :T].view(Vr.size(0), Vr.size(1) // 2, 2)
97
98     # [:, :, 0] < [:, :, 1]
99     dv0 = (Vf[:, :, 0] - Vf[:, :, 0].detach())[:, :, None]
100     dv = (Vf[:, :, 1] - Vf[:, :, 1].detach())[:, :, None]
101     m = (Vf[:, :, 0] - s >= Vf[:, :, 1]).long()
102     Vv = m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1]
103     m = m[:, :, None]
104     Xx = m * Xf[:, :, 0] + (1 - m) * (Xf[:, :, 1] * (1 + dv) + Xf[:, :, 0] * dv0)
105
106     Xrf[:, :, 1], Vrf[:, :, 1] = pscan_diff(Xx, Vv, s * 2)
107
108     # [:, :-1, 1] < [:, 1:, 0]
109     dv0 = (Vrf[:, :-1, 1] - Vrf[:, :-1, 1].detach())[:, :, None]
110     dv = (Vf[:, 1:, 0] - Vf[:, 1:, 0].detach())[:, :, None]
111     m = (Vrf[:, :-1, 1] - s >= Vf[:, 1:, 0]).long()
112     Vrf[:, 1:, 0] = m * (Vrf[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0]
113     m = m[:, :, None]
114     Xrf[:, 1:, 0] = m * Xrf[:, :-1, 1] + (1 - m) * (
115         Xf[:, 1:, 0] * (1 + dv) + Xrf[:, :-1, 1] * dv0
116     )
117
118     Xr[:, 0] = X[:, 0]
119     Vr[:, 0] = V[:, 0]
120
121     if T < X.size(1):
122         # [:, -2] < [:, -1]
123         dx = X[:, -2] - X[:, -2].detach()
124         dv = (V[:, -1] - V[:, -1].detach())[:, None]
125         m = (V[:, -2] - s >= V[:, -1]).long()
126         Vr[:, -1] = m * (Vr[:, -2] - s) + (1 - m) * V[:, -1]
127         m = m[:, None]
128         Xr[:, -1] = m * Xr[:, -2] + (1 - m) * (X[:, -1] * (1 + dv) + dx)
129
130     return Xr, Vr
131
132
133 ######################################################################
134
135 if __name__ == "__main__":
136     N = 1
137     T = 64
138     D = 128
139
140     torch.autograd.set_detect_anomaly(True)
141
142     for k in range(0):
143         X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
144         V = torch.rand(N, T, dtype=torch.float64)
145
146         X0, V0 = baseline(X, V)
147
148         # print("########### X0 V0 ###########################################")
149         # print(V0)
150         # print(X0)
151
152         X1, V1 = pscan_diff(X, V)
153
154         # print("########### X V ############################################")
155         # print(V)
156         # print(X)
157
158         error = ((X0 - X1).abs().max() + (V0 - V1).abs().max()).item()
159         if error > 0:
160             print("ERROR", error)
161             print(X0)
162             print(X1)
163             exit(0)
164
165     # exit(0)
166
167     # s = X1.sum()
168     # print(torch.autograd.grad(s, X))
169
170     # with open("/tmp/v.dat", "w") as f:
171     # for t in range(T):
172     # f.write(f"{V1[0,t].item()}\n")
173
174     Y = torch.randn(1, 1, D)
175     X = torch.randn(N, T, D) * 0.1
176
177     m = (torch.rand(N, T, 1).sort(dim=1).indices == 0).float()
178     X = (1 - m) * X + m * Y
179     V = torch.rand(N, T)  # + 100* m.squeeze(dim=-1)
180     V = V.requires_grad_()
181
182     optimizer = torch.optim.SGD([V], lr=1e-1)
183
184     for k in range(1000):
185         X1, V1 = baseline(X, V)
186         loss = (X1 - Y).pow(2).mean()
187         print(k, loss.item())
188         optimizer.zero_grad()
189         loss.backward()
190         optimizer.step()