Update.
[mygptrnn.git] / maxval.py
1 #!/usr/bin/env python
2
3 import torch
4
5 ######################################################################
6
7
8 def baseline(X, V):
9     Y = X.new(X.size())
10     W = V.new(V.size())
11     for t in range(X.size(1)):
12         if t == 0:
13             Y[:, t] = X[:, t]
14             W[:, t] = V[:, t]
15         else:
16             m = (V[:, t] >= W[:, t - 1] - 1).long()
17             Y[:, t] = m * X[:, t] + (1 - m) * Y[:, t - 1]
18             W[:, t] = m * V[:, t] + (1 - m) * (W[:, t - 1] - 1)
19
20     return Y, W
21
22
23 ######################################################################
24
25
26 def pscan(X, V, s=1):
27     if X.size(1) == 1:
28         return
29
30     T = 2 * (X.size(1) // 2)
31
32     Xf = X[:, :T].view(X.size(0), X.size(1) // 2, 2, X.size(2))
33     Vf = V[:, :T].view(V.size(0), V.size(1) // 2, 2)
34
35     # [:, :, 0] < [:, :, 1]
36     m = (Vf[:, :, 0] - s >= Vf[:, :, 1]).long()
37     Vf[:, :, 1] = m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1]
38     m = m[:, :, None]
39     Xf[:, :, 1] = m * Xf[:, :, 0] + (1 - m) * Xf[:, :, 1]
40
41     pscan(Xf[:, :, 1], Vf[:, :, 1], s * 2)
42
43     # [:, :-1, 1] < [:, 1:, 0]
44     m = (Vf[:, :-1, 1] - s >= Vf[:, 1:, 0]).long()
45     Vf[:, 1:, 0] = m * (Vf[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0]
46     m = m[:, :, None]
47     Xf[:, 1:, 0] = m * Xf[:, :-1, 1] + (1 - m) * Xf[:, 1:, 0]
48
49     if T < X.size(1):
50         # [:, -2] < [:, -1]
51         m = (V[:, -2] - s >= V[:, -1]).long()
52         V[:, -1] = m * (V[:, -2] - s) + (1 - m) * V[:, -1]
53         m = m[:, None]
54         X[:, -1] = m * X[:, -2] + (1 - m) * X[:, -1]
55
56
57 ######################################################################
58
59
60 def pscan_diff(X, V, s=1):
61     if X.size(1) == 1:
62         return X, V
63
64     T = 2 * (X.size(1) // 2)
65
66     Xf = X[:, :T].view(X.size(0), X.size(1) // 2, 2, X.size(2))
67     Vf = V[:, :T].view(V.size(0), V.size(1) // 2, 2)
68
69     Xr = X.new(X.size())
70     Vr = V.new(V.size())
71     Xrf = Xr[:, :T].view(Xr.size(0), Xr.size(1) // 2, 2, Xr.size(2))
72     Vrf = Vr[:, :T].view(Vr.size(0), Vr.size(1) // 2, 2)
73
74     # [:, :, 0] < [:, :, 1]
75     dx = Xf[:, :, 1] - Xf[:, :, 1].detach()
76     dv = (Vf[:, :, 1] - Vf[:, :, 1].detach())[:, :, None]
77     m = (Vf[:, :, 0] - s >= Vf[:, :, 1]).long()
78     Vv = m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1]
79     m = m[:, :, None]
80     Xx = m * Xf[:, :, 0] + (1 - m) * (Xf[:, :, 1] * (1 + dv) + dx)
81
82     Xrf[:, :, 1], Vrf[:, :, 1] = pscan_diff(Xx, Vv, s * 2)
83
84     Xr[:, 0] = X[:, 0]
85     Vr[:, 0] = V[:, 0]
86
87     # [:, :-1, 1] < [:, 1:, 0]
88     dx = Xf[:, 1:, 0] - Xf[:, 1:, 0].detach()
89     dv = (Vf[:, 1:, 0] - Vf[:, 1:, 0].detach())[:, :, None]
90     m = (Vrf[:, :-1, 1] - s >= Vf[:, 1:, 0]).long()
91     Vrf[:, 1:, 0] = m * (Vrf[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0]
92     m = m[:, :, None]
93     Xrf[:, 1:, 0] = m * Xrf[:, :-1, 1] + (1 - m) * (Xf[:, 1:, 0] * (1 + dv) + dx)
94
95     if T < X.size(1):
96         # [:, -2] < [:, -1]
97         dx = X[:, -1] - X[:, -1].detach()
98         dv = (V[:, -1] - V[:, -1].detach())[:, None]
99         m = (V[:, -2] - s >= V[:, -1]).long()
100         Vr[:, -1] = m * (V[:, -2] - s) + (1 - m) * V[:, -1]
101         m = m[:, None]
102         Xr[:, -1] = m * X[:, -2] + (1 - m) * (X[:, -1] * (1 + dv) + dx)
103
104     return Xr, Vr
105
106
107 ######################################################################
108
109 if __name__ == "__main__":
110     N = 1
111     T = 513
112     D = 2
113
114     X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
115     V = torch.rand(N, T, dtype=torch.float64) * 10
116
117     X0, V0 = baseline(X, V)
118
119     # print("########### X0 V0 ###########################################")
120     # print(V0)
121     # print(X0)
122
123     X1, V1 = pscan_diff(X, V)
124
125     # print("########### X V ############################################")
126     # print(V)
127     # print(X)
128
129     print("ERROR", ((X0 - X1).abs().max() + (V0 - V1).abs().max()).item())
130
131     exit(0)
132
133     # s = X1.sum()
134     # print(torch.autograd.grad(s, X))
135
136     # with open("/tmp/v.dat", "w") as f:
137     # for t in range(T):
138     # f.write(f"{V1[0,t].item()}\n")
139
140     Y = torch.randn(1, 1, D)
141     X = torch.randn(
142         N, T, D
143     )  # * 0.1 + (torch.rand(N,T,1).sort(dim=1).indices==0).float() * Y
144     V = torch.rand(N, T).requires_grad_()
145
146     optimizer = torch.optim.SGD([V], lr=1e-2)
147
148     for k in range(1000):
149         X1, V1 = X.clone(), V.clone()
150         pscan(X, V, X1, V1)
151         # X1=X1*(1+V1-V1.detach())[:,:,None]
152         loss = (X1[:, -1:] - Y).pow(2).mean()
153         print(k, loss.item())
154         optimizer.zero_grad()
155         loss.backward()
156         optimizer.step()