Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 22 Jul 2023 08:06:45 +0000 (10:06 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 22 Jul 2023 08:06:45 +0000 (10:06 +0200)
graph.py [new file with mode: 0755]
mygpt.py
tasks.py

diff --git a/graph.py b/graph.py
new file mode 100755 (executable)
index 0000000..97de6d1
--- /dev/null
+++ b/graph.py
@@ -0,0 +1,115 @@
+#!/usr/bin/env python
+
+import math
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+import cairo
+
+
+######################################################################
+def save_attention_image(
+    filename,
+    tokens,
+    attention,
+    surface_width=128,
+    surface_height=96,
+    pixel_scale=8,
+    x=10,
+    y=10,
+    token_gap=15,
+    layer_gap=25,
+    y_eps=1,
+    min_att=1e-2,
+):
+    # surface = cairo.PDFSurface(
+    # filename, surface_width * pixel_scale, surface_height * pixel_scale
+    # )
+
+    surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None)
+
+    ctx = cairo.Context(surface)
+    ctx.scale(pixel_scale, pixel_scale)
+
+    ctx.set_source_rgb(0.0, 0.0, 0.0)
+    ctx.set_font_size(4.0)
+    # ctx.select_font_face("Arial", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)
+
+    u = []
+    for n, t in enumerate(tokens):
+        string = str(t)
+        (
+            x_bearing,
+            y_bearing,
+            width_t,
+            height_t,
+            x_advance,
+            y_advance,
+        ) = ctx.text_extents(string)
+        u.append((n, string, x, x + width_t / 2, height_t, y_bearing))
+        x += x_advance + token_gap
+    tokens = u
+
+    for d in range(attention.size(0) + 1):
+        for n, s, x, xc, h, yb in tokens:
+            # ctx.set_source_rgb(0.0, 0.0, 0.0)
+            # ctx.rectangle(x+x_bearing,y+y_bearing,width_t,height_t)
+            # ctx.stroke()
+            ctx.set_source_rgb(0.0, 0.0, 0.0)
+            ctx.move_to(x, y)
+            ctx.show_text(s)
+            # x += x_advance + 1
+            if d < attention.size(0):
+                for m, _, _, x2c, h2, y2b in tokens:
+                    if attention[d, n, m] >= min_att:
+                        c = 1 - attention[d, n, m]
+                        ctx.set_source_rgb(c, c, c)
+                        ctx.set_line_width(0.5)
+                        ctx.move_to(xc, y + yb + h + y_eps)
+                        ctx.line_to(x2c, y + layer_gap + y2b - y_eps)
+                        ctx.stroke()
+        y += layer_gap
+
+    x, y, width, height = surface.ink_extents()
+    pdf_surface = cairo.PDFSurface(filename, width, height)
+    ctx_pdf = cairo.Context(pdf_surface)
+    ctx_pdf.set_source_surface(surface, -x, -y)
+    ctx_pdf.paint()
+    pdf_surface.finish()
+
+
+######################################################################
+
+if __name__ == "__main__":
+    import mygpt
+
+    vocabulary_size = 3
+    x = torch.randint(vocabulary_size, (1, 5))
+
+    model = mygpt.MyGPT(
+        vocabulary_size=vocabulary_size,
+        dim_model=4,
+        dim_keys=2,
+        dim_hidden=2,
+        nb_heads=2,
+        nb_blocks=3,
+        dropout=0.1,
+        causal=True,
+    )
+
+    model.eval()
+    model.record_attention()
+
+    y1 = model(mygpt.BracketedSequence(x)).x
+
+    a = model.retrieve_attention()
+    print(a)
+    attention = torch.cat([x[:0] for x in a], dim=0)
+
+    tokens = ["bluh", 2, 3, 4, "blih"]
+    attention = torch.randn(3, len(tokens), len(tokens)).softmax(dim=-1)
+
+    save_attention_image("attention.pdf", tokens, attention)
index 45b7b59..ac1c55e 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -116,7 +116,13 @@ class AddPositionalEncoding(nn.Module):
 
 class QKVAttention(nn.Module):
     def __init__(
-        self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0
+        self,
+        dim_in,
+        dim_qk,
+        dim_v,
+        nb_heads=1,
+        causal=False,
+        attention_dropout=0.0,
     ):
         super().__init__()
 
@@ -125,6 +131,7 @@ class QKVAttention(nn.Module):
 
         self.causal = causal
         self.attention_dropout = attention_dropout
+        self.record_attention = False
 
         self.w_q = randw(nb_heads, dim_qk, dim_in)
         self.w_k = randw(nb_heads, dim_qk, dim_in)
@@ -162,6 +169,9 @@ class QKVAttention(nn.Module):
             "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb]
         ) / math.sqrt(self.w_q.size(1))
 
+        if self.record_attention:
+            self.a = a
+
         if self.causal:
             if bs_q.first == 0:
                 self.cache_attzero = (
@@ -283,6 +293,18 @@ class MyGPT(nn.Module):
                 t_next = dist.sample()
             input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
 
+    def record_attention(self, v=True):
+        for m in self.modules():
+            if isinstance(m, QKVAttention):
+                m.record_attention = v
+
+    def retrieve_attention(self):
+        a = []
+        for m in self.modules():
+            if isinstance(m, QKVAttention):
+                a.append(m.a)
+        return a
+
 
 ######################################################################
 
@@ -298,13 +320,12 @@ if __name__ == "__main__":
         dim_keys=2,
         dim_hidden=2,
         nb_heads=2,
-        nb_blocks=1,
+        nb_blocks=2,
         dropout=0.1,
         causal=True,
     )
 
     model.eval()
-
     y1 = model(BracketedSequence(x)).x
     y2 = torch.randn_like(y1)
     for s in range(x.size(1)):
index ca71182..af71b85 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -1111,6 +1111,7 @@ class RPL(Task):
         self.test_input = self.tensorize(test_sequences)
 
         if no_prog:
+            # Excise the program from every train and test example
             k = torch.arange(self.train_input.size(1), device=self.train_input.device)[
                 None, :
             ]
@@ -1185,13 +1186,13 @@ class RPL(Task):
             )
 
             sum_nb_total, sum_nb_errors = 0, 0
-            for x, y in zip(input, result):
-                seq = [self.id2token[i.item()] for i in y]
+            for one_input, one_result in zip(input, result):
+                seq = [self.id2token[i.item()] for i in one_result]
                 nb_total, nb_errors, prog, stacks = rpl.compute_nb_errors(seq)
                 sum_nb_total += 1
                 sum_nb_errors += 0 if nb_errors == 0 else 1
                 if nb_to_log > 0:
-                    gt_seq = [self.id2token[i.item()] for i in x]
+                    gt_seq = [self.id2token[i.item()] for i in one_input]
                     _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq)
                     gt_prog = " ".join([str(x) for x in gt_prog])
                     prog = " ".join([str(x) for x in prog])
@@ -1232,14 +1233,20 @@ class RPL(Task):
             )
 
             sum_nb_total, sum_nb_errors = 0, 0
-            for x, y, i, j in zip(input, result, last_output_idx, first_prog_idx):
-                seq = [self.id2token[i.item()] for i in y]
+            for one_input, one_result, i, j in zip(
+                input, result, last_output_idx, first_prog_idx
+            ):
+                seq = [self.id2token[i.item()] for i in one_result]
                 sum_nb_total += 1
-                correct = (x - y).abs().max() == 0
+                correct = (one_input - one_result).abs().max() == 0
                 sum_nb_errors += 0 if correct else 1
                 if nb_to_log > 0:
-                    result_stack = [self.id2token[i.item()] for i in y[i : j + 1]]
-                    target_stack = [self.id2token[i.item()] for i in x[i : j + 1]]
+                    result_stack = [
+                        self.id2token[i.item()] for i in one_result[i : j + 1]
+                    ]
+                    target_stack = [
+                        self.id2token[i.item()] for i in one_input[i : j + 1]
+                    ]
                     comment = "*" if correct else "-"
                     result_stack = " ".join([str(x) for x in result_stack])
                     target_stack = " ".join([str(x) for x in target_stack])