From ce9e6321c97778f5aa7d2ecaaff640a3e45bef13 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 29 Aug 2023 08:37:08 +0200 Subject: [PATCH] Update. --- graph.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/graph.py b/graph.py index 6db9ed7..07e376a 100755 --- a/graph.py +++ b/graph.py @@ -14,10 +14,12 @@ import cairo def save_attention_image( - filename, # image to save + # image to save + filename, tokens_input, tokens_output, - attention_matrices, # list of 2d tensors T1xT2, T2xT3, ..., Tk-1xTk + # list of 2d tensors T2xT1, T3xT2, ..., TkxTk-1 + attention_matrices, # do not draw links with a lesser attention min_link_attention=0, # draw only the strongest links necessary so that their summed @@ -25,6 +27,7 @@ def save_attention_image( min_total_attention=None, # draw only the top k links k_top=None, + # the purely graphical settings curved=True, pixel_scale=8, token_gap=15, @@ -170,8 +173,7 @@ if __name__ == "__main__": attention_matrices = [m[0, 0] for m in model.retrieve_attention()] - # attention_matrices = [ torch.rand(3,5), torch.rand(8,3), torch.rand(5,8) ] - # for a in attention_matrices: a=a/a.sum(-1,keepdim=True) + # attention_matrices = [torch.rand(*s) for s in [ (4,5),(3,4),(8,3),(5,8) ]] save_attention_image( "attention.pdf", -- 2.20.1