nb_heads=2,
nb_blocks=5,
dropout=0.1,
- causal=True,
+ #causal=True,
)
# Call eval() on the model before reading its attention. Presumably this is a
# torch module and eval() turns off dropout so the retrieved values are
# deterministic — TODO confirm against the model class.
model.eval()
# Keep only the [0, 0] slice of each item returned by retrieve_attention()
# (presumably first batch element, first head — verify against the model).
attention_matrices = [m[0, 0] for m in model.retrieve_attention()]
+
+
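# attention_matrices = [ torch.rand(3,5), torch.rand(8,3), torch.rand(5,8) ]
# for a in attention_matrices: a=a/a.sum(-1,keepdim=True)
# NOTE: if re-enabled, the loop above only rebinds the loop variable `a`; the
# tensors stored in the list stay un-normalized. Use an in-place op (a /= ...)
# or rebuild the list to actually normalize each row.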