97de6d108daf435553fe8cb22c3c427a1af391fe
[picoclvr.git] / graph.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10 import cairo
11
12
13 ######################################################################
def save_attention_image(
    filename,
    tokens,
    attention,
    surface_width=128,
    surface_height=96,
    pixel_scale=8,
    x=10,
    y=10,
    token_gap=15,
    layer_gap=25,
    y_eps=1,
    min_att=1e-2,
    font_size=4.0,
    line_width=0.5,
):
    """Render a layered attention graph to a PDF file.

    The tokens are written as one row of text per layer, ``len(attention) + 1``
    rows in total, ``layer_gap`` apart. Between row ``d`` and row ``d + 1`` a
    straight segment is drawn from token ``n`` (just below its glyphs) to
    token ``m`` (just above its glyphs) whenever
    ``attention[d][n][m] >= min_att``. The segment is grey with darkness equal
    to the attention value: 1 is black, 0 is white.

    The drawing is first recorded on an unbounded recording surface, then
    copied onto a PDF page cropped to the inked area.

    Args:
        filename: path of the PDF file to write.
        tokens: iterable of tokens; each is rendered with ``str()``.
        attention: torch tensor, numpy array or nested sequence indexed as
            ``attention[layer][from_token][to_token]``. Both token
            dimensions must be at least ``len(tokens)``.
        surface_width, surface_height: unused; kept so existing callers
            keep working.
        pixel_scale: scale factor applied to every drawing coordinate.
        x, y: baseline-left origin of the first token of the first row, in
            unscaled units.
        token_gap: horizontal space added after each token's advance.
        layer_gap: vertical distance between consecutive rows.
        y_eps: gap between a segment end and the text it connects to.
        min_att: attention values below this threshold are not drawn.
        font_size: font size of the token labels.
        line_width: width of the attention segments.
    """
    if torch.is_tensor(attention):
        # One host copy as nested Python floats. This avoids per-element
        # tensor indexing (and a device sync per element on GPU) in the
        # drawing loop, and gives cairo plain floats.
        attention = attention.detach().cpu().tolist()
    nb_layers = len(attention)

    surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None)

    ctx = cairo.Context(surface)
    ctx.scale(pixel_scale, pixel_scale)
    ctx.set_font_size(font_size)
    # Line width is interpreted in the user space in effect at stroke time,
    # so setting it once here applies to every segment below.
    ctx.set_line_width(line_width)

    # Lay the tokens out once:
    # (label, left x, center x, ink height, y bearing).
    layout = []
    for t in tokens:
        label = str(t)
        _, y_bearing, width_t, height_t, x_advance, _ = ctx.text_extents(label)
        layout.append((label, x, x + width_t / 2, height_t, y_bearing))
        x += x_advance + token_gap

    for d in range(nb_layers + 1):
        for n, (label, tx, xc, h, yb) in enumerate(layout):
            ctx.set_source_rgb(0.0, 0.0, 0.0)
            ctx.move_to(tx, y)
            ctx.show_text(label)
            if d < nb_layers:
                row = attention[d][n]
                for m, (_, _, x2c, _, y2b) in enumerate(layout):
                    a = row[m]
                    if a >= min_att:
                        c = 1 - float(a)
                        ctx.set_source_rgb(c, c, c)
                        # y + yb + h is the bottom of this token's ink;
                        # y + layer_gap + y2b is the top of the target
                        # token's ink on the next row.
                        ctx.move_to(xc, y + yb + h + y_eps)
                        ctx.line_to(x2c, y + layer_gap + y2b - y_eps)
                        ctx.stroke()
        y += layer_gap

    # The recording surface is unbounded, so copy it onto a PDF page of
    # exactly the inked size, shifted to the origin.
    ink_x, ink_y, width, height = surface.ink_extents()
    pdf_surface = cairo.PDFSurface(filename, width, height)
    try:
        ctx_pdf = cairo.Context(pdf_surface)
        ctx_pdf.set_source_surface(surface, -ink_x, -ink_y)
        ctx_pdf.paint()
    finally:
        # Flush and close the output file even if painting raised.
        pdf_surface.finish()
82
83
84 ######################################################################
85
if __name__ == "__main__":
    # Demo: run a tiny MyGPT with attention recording enabled, print what was
    # recorded, then render an attention graph to attention.pdf.
    import mygpt

    vocabulary_size = 3
    x = torch.randint(vocabulary_size, (1, 5))

    model = mygpt.MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=4,
        dim_keys=2,
        dim_hidden=2,
        nb_heads=2,
        nb_blocks=3,
        dropout=0.1,
        causal=True,
    )

    model.eval()
    model.record_attention()

    # Inference only: the forward pass exists to fill the attention records.
    with torch.no_grad():
        model(mygpt.BracketedSequence(x))

    a = model.retrieve_attention()
    print(a)
    # Keep the first (only) sample of each recorded attention tensor and
    # stack them along dim 0. The original sliced with [:0], which kept
    # nothing. Presumably each record is (batch, heads, T, T), so this gives
    # (nb_blocks * nb_heads, T, T) — verify against mygpt.
    attention = torch.cat([layer_att[0] for layer_att in a], dim=0)

    # NOTE(review): the recorded attention above is currently replaced by a
    # random row-stochastic one over placeholder tokens, to exercise the
    # drawing code.
    tokens = ["bluh", 2, 3, 4, "blih"]
    attention = torch.randn(3, len(tokens), len(tokens)).softmax(dim=-1)

    save_attention_image("attention.pdf", tokens, attention)