-def check_causality(model):
- #m = model[1:]
- input = torch.rand(1, 5, dim_model).requires_grad_()
- output = m(input)
- a = torch.zeros(output.size(1), input.size(1))
- for k in range(output.size(1)):
- for d in range(output.size(2)):
- g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
- a[k] += g.squeeze(0).pow(2).sum(1)
- print(a)
-
-######################################################################
-