Initial commit
[mygptrnn.git] / pscan.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import torch
9
10 ######################################################################
11
12
13 class PScan(torch.autograd.Function):
14     # Given A is NxTx1 and X is NxTxD, expands A and X in place in O(T),
15     # and O(log(T)) if not core-bounded, so that
16     #
17     # Y[:, 0] = Y_init
18     # Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t]
19     #
20     # can be computed as
21     #
22     # Y[:, t] = A[:, t] * Y_init + X[:, t]
23
24     @staticmethod
25     def expand_(A, X):
26         if A.size(1) == 1:
27             return
28         T = 2 * (A.size(1) // 2)
29         Aa = A[:, :T].view(A.size(0), T // 2, 2, -1, 1)
30         Xa = X[:, :T].view(X.size(0), T // 2, 2, -1, X.size(-1))
31         Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
32         Aa[:, :, 1].mul_(Aa[:, :, 0])
33         PScan.expand_(Aa[:, :, 1], Xa[:, :, 1])
34         Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1]))
35         Aa[:, 1:, 0].mul_(Aa[:, :-1, 1])
36         if T < A.size(1):
37             X[:, -1].add_(A[:, -1].mul(X[:, -2]))
38             A[:, -1].mul_(A[:, -2])
39
40     @staticmethod
41     def acc_rev_(A, X):
42         if X.size(1) == 1:
43             return
44         T = 2 * (X.size(1) // 2)
45         Aa = A[:, -T:].view(A.size(0), T // 2, 2, -1, 1)
46         Xa = X[:, -T:].view(X.size(0), T // 2, 2, -1, X.size(-1))
47         Xa[:, :, 0].add_(Aa[:, :, 1].mul(Xa[:, :, 1]))
48         B = Aa[:, :, 0].clone()
49         B[:, 1:].mul_(Aa[:, :-1, 1])
50         PScan.acc_rev_(B, Xa[:, :, 0])
51         Xa[:, :-1, 1].add_(Aa[:, 1:, 0].mul(Xa[:, 1:, 0]))
52         if T < A.size(1):
53             X[:, 0].add_(A[:, 1].mul(X[:, 1]))
54
55     # A is NxT, X is NxTxD, Y_init is NxD
56     #
57     # returns Y of same shape as X, with
58     #
59     # Y[:, t] = A[:, 0] * Y_init   + X[:, 0] if t == 0
60     #         = A[:, t] * Y[:, t-1] + X[:, t] otherwise
61
62     @staticmethod
63     def forward(ctx, A, X, Y_init):
64         ctx.A = A.unsqueeze(-1).clone()
65         ctx.Y_init = Y_init[:, None].clone()
66         ctx.A_star = ctx.A.clone()
67         ctx.X_star = X.clone()
68         PScan.expand_(ctx.A_star, ctx.X_star)
69         return ctx.A_star * ctx.Y_init + ctx.X_star
70
71     @staticmethod
72     def backward(ctx, grad_output):
73         U = grad_output * ctx.A_star
74         A = ctx.A.clone()
75         R = grad_output.clone()
76         PScan.acc_rev_(A, R)
77         Q = ctx.Y_init.expand_as(ctx.X_star).clone()
78         Q[:, 1:].mul_(ctx.A_star[:, :-1]).add_(ctx.X_star[:, :-1])
79         return (Q * R).sum(-1), R, U.sum(dim=1)
80
81
82 pscan = PScan.apply
83
84 ######################################################################
85
86 if __name__ == "__main__":
87     import time, sys
88
89     A = torch.rand(17, 12, 3)
90     X = torch.rand(17, 12, 3, 11)
91     Y_init = torch.rand(17, 3, 11)
92     Y = pscan(A, X, Y_init)
93     exit(0)
94
95     N, T, D = 2, 1047, 3
96
97     A = torch.rand(N, T, dtype=torch.float64).requires_grad_()
98     X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
99     Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_()
100
101     # Iterative implementation
102
103     y = Y_init
104     s = 0
105
106     for k in range(A.size(1)):
107         y = A[:, k, None] * y + X[:, k]
108         s = s + y
109
110     s = s.sum()
111
112     gA_ref, gX_ref, gY_init_ref = torch.autograd.grad(
113         s, (A, X, Y_init), retain_graph=True
114     )
115
116     # parallel scan
117
118     start_time = time.perf_counter()
119     for _ in range(1000):
120         Y = pscan(A, X, Y_init)
121     duration = time.perf_counter() - start_time
122     print(f"duration {duration}")
123
124     s = Y.sum()
125
126     gA, gX, gY_init = torch.autograd.grad(s, (A, X, Y_init), retain_graph=True)
127
128     # print(gA)
129     # print(gX)
130     # print(gY_init)
131
132     print((gA - gA_ref).norm())
133     print((gX - gX_ref).norm())
134     print((gY_init - gY_init_ref).norm())
135
136     Y1 = pscan(A[:, : T // 2], X[:, : T // 2], Y_init)
137     Y2 = pscan(A[:, T // 2 :], X[:, T // 2 :], Y1[:, -1])
138
139     print((Y - torch.cat([Y1, Y2], dim=1)).norm())