- # print(s)
- print(torch.autograd.grad(s, A, retain_graph=True))
- print(torch.autograd.grad(s, X, retain_graph=True))
- print(torch.autograd.grad(s, Y0, retain_graph=True))
+ gA, gX, gY0 = torch.autograd.grad(s, (A, X, Y0), retain_graph=True)
+
+ print((gA - gA_ref).norm())
+ print((gX - gX_ref).norm())
+ print((gY0 - gY0_ref).norm())