- # print(s)
- print(torch.autograd.grad(s, A, retain_graph=True))
- print(torch.autograd.grad(s, X, retain_graph=True))
- print(torch.autograd.grad(s, Y0, retain_graph=True))
+ # Gradients of s with respect to A, X and Y_init, obtained in one autograd
+ # call; retain_graph=True keeps the graph so s can be differentiated again.
+ gA, gX, gY_init = torch.autograd.grad(s, (A, X, Y_init), retain_graph=True)
+
+ # Print the norm of the gap between each gradient and its *_ref counterpart.
+ # NOTE(review): the *_ref tensors are defined outside this hunk, presumably
+ # from a reference implementation -- confirm. Values near zero mean agreement.
+ print((gA - gA_ref).norm())
+ print((gX - gX_ref).norm())
+ print((gY_init - gY_init_ref).norm())
+
+ # Chunked evaluation: run pscan on the first T // 2 entries along dim 1, then
+ # on the remaining entries, seeding the second call with the last dim-1 entry
+ # of the first call's output.
+ Y1 = pscan(A[:, : T // 2], X[:, : T // 2], Y_init)
+ Y2 = pscan(A[:, T // 2 :], X[:, T // 2 :], Y1[:, -1])
+
+ # Print the norm of the gap between Y and the two halves concatenated along
+ # dim 1. NOTE(review): Y is defined outside this hunk, presumably the
+ # single-call pscan output over the full range -- confirm.
+ print((Y - torch.cat([Y1, Y2], dim=1)).norm())