Update.
[mygptrnn.git] / pscan.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import torch
9
10 ######################################################################
11
12
13 class PScan(torch.autograd.Function):
14     # Given A is NxTxMx1 and X is NxTxMxD, expands A and X in
15     # place in O(T), and O(log(T)) if not core-bounded, so that
16     #
17     # Y[:, 0] = Y_init
18     # Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t]
19     #
20     # can be computed as
21     #
22     # Y[:, t] = A[:, t] * Y_init + X[:, t]
23
24     @staticmethod
25     def expand_(A, X):
26         # Unrolling gains ~8% speed
27
28         if A.size(1) > 4:
29             T = 2 * (A.size(1) // 2)
30             Aa = A[:, :T].view(A.size(0), T // 2, 2, -1, 1)
31             Xa = X[:, :T].view(X.size(0), T // 2, 2, -1, X.size(-1))
32             Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
33             Aa[:, :, 1].mul_(Aa[:, :, 0])
34             PScan.expand_(Aa[:, :, 1], Xa[:, :, 1])
35             Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1]))
36             Aa[:, 1:, 0].mul_(Aa[:, :-1, 1])
37             if T < A.size(1):
38                 X[:, -1].add_(A[:, -1].mul(X[:, -2]))
39                 A[:, -1].mul_(A[:, -2])
40         elif A.size(1) == 2:
41             X[:, 1].add_(A[:, 1].mul(X[:, 0]))
42             A[:, 1].mul_(A[:, 0])
43         elif A.size(1) == 3:
44             X[:, 1].add_(A[:, 1].mul(X[:, 0]))
45             A[:, 1].mul_(A[:, 0])
46             X[:, 2].add_(A[:, 2].mul(X[:, 1]))
47             A[:, 2].mul_(A[:, 1])
48         elif A.size(1) == 4:
49             X[:, 1].add_(A[:, 1].mul(X[:, 0]))
50             A[:, 1].mul_(A[:, 0])
51             X[:, 2].add_(A[:, 2].mul(X[:, 1]))
52             A[:, 2].mul_(A[:, 1])
53             X[:, 3].add_(A[:, 3].mul(X[:, 2]))
54             A[:, 3].mul_(A[:, 2])
55
56     @staticmethod
57     def acc_rev_(A, X):
58         if A.size(1) > 4:
59             T = 2 * (X.size(1) // 2)
60             Aa = A[:, -T:].view(A.size(0), T // 2, 2, -1, 1)
61             Xa = X[:, -T:].view(X.size(0), T // 2, 2, -1, X.size(-1))
62             Xa[:, :, 0].add_(Aa[:, :, 1].mul(Xa[:, :, 1]))
63             B = Aa[:, :, 0].clone()
64             B[:, 1:].mul_(Aa[:, :-1, 1])
65             PScan.acc_rev_(B, Xa[:, :, 0])
66             Xa[:, :-1, 1].add_(Aa[:, 1:, 0].mul(Xa[:, 1:, 0]))
67             if T < A.size(1):
68                 X[:, 0].add_(A[:, 1].mul(X[:, 1]))
69         elif A.size(1) == 2:
70             X[:, 0].add_(A[:, 1].mul(X[:, 1]))
71         elif A.size(1) == 3:
72             X[:, 1].add_(A[:, 2].mul(X[:, 2]))
73             X[:, 0].add_(A[:, 1].mul(X[:, 1]))
74         elif A.size(1) == 4:
75             X[:, 2].add_(A[:, 3].mul(X[:, 3]))
76             X[:, 1].add_(A[:, 2].mul(X[:, 2]))
77             X[:, 0].add_(A[:, 1].mul(X[:, 1]))
78
79     # A is NxT, X is NxTxD, Y_init is NxD
80     #
81     # returns Y of same shape as X, with
82     #
83     # Y[:, t] = A[:, 0] * Y_init   + X[:, 0] if t == 0
84     #         = A[:, t] * Y[:, t-1] + X[:, t] otherwise
85
86     @staticmethod
87     def forward(ctx, A, X, Y_init):
88         ctx.A = A.unsqueeze(-1).clone()
89         ctx.Y_init = Y_init[:, None].clone()
90         ctx.A_star = ctx.A.clone()
91         ctx.X_star = X.clone()
92         PScan.expand_(ctx.A_star, ctx.X_star)
93         return ctx.A_star * ctx.Y_init + ctx.X_star
94
95     @staticmethod
96     def backward(ctx, grad_output):
97         U = grad_output * ctx.A_star
98         A = ctx.A.clone()
99         R = grad_output.clone()
100         PScan.acc_rev_(A, R)
101         Q = ctx.Y_init.expand_as(ctx.X_star).clone()
102         Q[:, 1:].mul_(ctx.A_star[:, :-1]).add_(ctx.X_star[:, :-1])
103         return (Q * R).sum(-1), R, U.sum(dim=1)
104
105
106 pscan = PScan.apply
107
108
109 def naive_pscan(A, X, Y_init):
110     y = Y_init
111     s = 0
112
113     for k in range(A.size(1)):
114         y = A[:, k, None] * y + X[:, k]
115         s = s + y
116
117     s = s.sum()
118
119
120 ######################################################################
121
122 if __name__ == "__main__":
123     import time, sys
124
125     ######################################################################
126
127     N, T, D = 16, 4096, 32
128
129     for r in range(timing.size(0)):
130         A = 0.9 + 0.1 * torch.rand(N, T, dtype=torch.float64).requires_grad_()
131         X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
132         Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_()
133
134         start_time = time.perf_counter()
135         for _ in range(1000):
136             Y = pscan(A, X, Y_init)
137         duration = time.perf_counter() - start_time
138
139     ######################################################################
140
141     # A = torch.rand(17, 12, 3)
142     # X = torch.rand(17, 12, 3, 11)
143     # Y_init = torch.rand(17, 3, 11)
144     # Y = pscan(A, X, Y_init)
145
146     # exit(0)
147
148     err = 0
149     timing = torch.empty(10)
150
151     for r in range(timing.size(0)):
152         N, T, D = 2, 1120, 3
153
154         # T = torch.randint(10, (1,)).item() + 1
155
156         A = 0.9 + 0.1 * torch.rand(N, T, dtype=torch.float64).requires_grad_()
157         X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
158         Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_()
159
160         # Iterative implementation
161
162         y = Y_init
163         s = 0
164
165         for k in range(A.size(1)):
166             y = A[:, k, None] * y + X[:, k]
167             s = s + y
168
169         s = s.sum()
170
171         gA_ref, gX_ref, gY_init_ref = torch.autograd.grad(
172             s, (A, X, Y_init), retain_graph=True
173         )
174
175         # parallel scan
176
177         start_time = time.perf_counter()
178         for _ in range(1000):
179             Y = pscan(A, X, Y_init)
180         duration = time.perf_counter() - start_time
181
182         print(f"duration {duration}")
183         timing[r] = duration
184
185         s = Y.sum()
186
187         gA, gX, gY_init = torch.autograd.grad(s, (A, X, Y_init), retain_graph=True)
188
189         err = max(
190             [
191                 err,
192                 (gA - gA_ref).abs().max(),
193                 (gX - gX_ref).abs().max(),
194                 (gY_init - gY_init_ref).abs().max(),
195             ]
196         )
197
198         # Y1 = pscan(A[:, : T // 2], X[:, : T // 2], Y_init)
199         # Y2 = pscan(A[:, T // 2 :], X[:, T // 2 :], Y1[:, -1])
200
201         # print((Y - torch.cat([Y1, Y2], dim=1)).abs().max())
202
203     print(f"err={err:.2e} duration={timing.mean():.2e} (+/- {timing.std():.2e})")