Update.
author François Fleuret <francois@fleuret.org>
Sat, 22 Jul 2023 12:29:26 +0000 (14:29 +0200)
committer François Fleuret <francois@fleuret.org>
Sat, 22 Jul 2023 12:29:26 +0000 (14:29 +0200)
graph.py

index 97de6d1..0db7bd0 100755 (executable)
--- a/graph.py
+++ b/graph.py
@@ -15,14 +15,11 @@ def save_attention_image(
     filename,
     tokens,
     attention,
-    surface_width=128,
-    surface_height=96,
     pixel_scale=8,
-    x=10,
-    y=10,
-    token_gap=15,
+    token_gap=10,
     layer_gap=25,
-    y_eps=1,
+    y_eps=1.5,
+    padding=0,
     min_att=1e-2,
 ):
     # surface = cairo.PDFSurface(
@@ -38,7 +35,9 @@ def save_attention_image(
     ctx.set_font_size(4.0)
     # ctx.select_font_face("Arial", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)
 
-    u = []
+    x, y = 0, 0
+
+    u = {}
     for n, t in enumerate(tokens):
         string = str(t)
         (
@@ -49,21 +48,14 @@ def save_attention_image(
             x_advance,
             y_advance,
         ) = ctx.text_extents(string)
-        u.append((n, string, x, x + width_t / 2, height_t, y_bearing))
+        u[n]=(string, x, x + width_t / 2, height_t, y_bearing)
         x += x_advance + token_gap
     tokens = u
 
     for d in range(attention.size(0) + 1):
-        for n, s, x, xc, h, yb in tokens:
-            # ctx.set_source_rgb(0.0, 0.0, 0.0)
-            # ctx.rectangle(x+x_bearing,y+y_bearing,width_t,height_t)
-            # ctx.stroke()
-            ctx.set_source_rgb(0.0, 0.0, 0.0)
-            ctx.move_to(x, y)
-            ctx.show_text(s)
-            # x += x_advance + 1
+        for n, (s, x, xc, h, yb) in tokens.items():
             if d < attention.size(0):
-                for m, _, _, x2c, h2, y2b in tokens:
+                for m, (_, _, x2c, h2, y2b) in tokens.items():
                     if attention[d, n, m] >= min_att:
                         c = 1 - attention[d, n, m]
                         ctx.set_source_rgb(c, c, c)
@@ -71,9 +63,20 @@ def save_attention_image(
                         ctx.move_to(xc, y + yb + h + y_eps)
                         ctx.line_to(x2c, y + layer_gap + y2b - y_eps)
                         ctx.stroke()
+            # ctx.set_source_rgb(0.0, 0.0, 0.0)
+            # ctx.rectangle(x+x_bearing,y+y_bearing,width_t,height_t)
+            # ctx.stroke()
+            ctx.set_source_rgb(0.0, 0.0, 0.0)
+            ctx.move_to(x, y)
+            ctx.show_text(s)
+            # x += x_advance + 1
         y += layer_gap
 
     x, y, width, height = surface.ink_extents()
+    x -= padding
+    y -= padding
+    width += 2 * padding
+    height += 2 * padding
     pdf_surface = cairo.PDFSurface(filename, width, height)
     ctx_pdf = cairo.Context(pdf_surface)
     ctx_pdf.set_source_surface(surface, -x, -y)
@@ -112,4 +115,4 @@ if __name__ == "__main__":
     tokens = ["bluh", 2, 3, 4, "blih"]
     attention = torch.randn(3, len(tokens), len(tokens)).softmax(dim=-1)
 
-    save_attention_image("attention.pdf", tokens, attention)
+    save_attention_image("attention.pdf", tokens, attention, padding=3)