# picoclvr.git / graph.py
#!/usr/bin/env python

import math

import torch, torchvision

from torch import nn
from torch.nn import functional as F

import cairo


######################################################################
def save_attention_image(
    filename,
    tokens,
    attention,
    pixel_scale=8,
    token_gap=10,
    layer_gap=25,
    y_eps=1.5,
    padding=0,
    min_att=1e-2,
):
    """Render per-layer token-to-token attention as a PDF diagram.

    One row of token labels is drawn per layer (plus a final output row);
    a grey segment links token n of row d to token m of row d+1 whenever
    attention[d, n, m] >= min_att, darker meaning stronger.

    Args:
        filename: output PDF path.
        tokens: sequence of token labels; each is rendered with str().
        attention: tensor of shape (nb_layers, nb_tokens, nb_tokens),
            values assumed in [0, 1] -- TODO confirm against callers.
        pixel_scale: global scale factor applied to the whole drawing.
        token_gap: horizontal spacing added after each token.
        layer_gap: vertical spacing between successive rows/layers.
        y_eps: vertical margin between a token and its attention lines.
        padding: extra margin added around the final ink extents.
        min_att: attention values below this threshold are not drawn.
    """
    # Draw onto an unbounded recording surface first; the final PDF page
    # is sized afterwards from the actual ink extents.
    surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None)

    ctx = cairo.Context(surface)
    ctx.scale(pixel_scale, pixel_scale)

    ctx.set_source_rgb(0.0, 0.0, 0.0)
    ctx.set_font_size(4.0)

    # Lay tokens out left to right; remember for each one its label,
    # left x, center x, glyph height and y bearing.
    x, y = 0, 0

    u = {}
    for n, t in enumerate(tokens):
        string = str(t)
        (
            x_bearing,
            y_bearing,
            width_t,
            height_t,
            x_advance,
            y_advance,
        ) = ctx.text_extents(string)
        u[n] = (string, x, x + width_t / 2, height_t, y_bearing)
        x += x_advance + token_gap
    tokens = u

    # Loop-invariant: same width for every edge, set once instead of
    # inside the innermost O(layers * n^2) loop.
    ctx.set_line_width(0.5)

    for d in range(attention.size(0) + 1):
        for n, (s, x, xc, h, yb) in tokens.items():
            if d < attention.size(0):
                for m, (_, _, x2c, h2, y2b) in tokens.items():
                    # .item() converts the 0-dim tensor once; cairo then
                    # receives a plain float.
                    a = attention[d, n, m].item()
                    if a >= min_att:
                        c = 1 - a  # darker <=> stronger attention
                        ctx.set_source_rgb(c, c, c)
                        ctx.move_to(xc, y + yb + h + y_eps)
                        ctx.line_to(x2c, y + layer_gap + y2b - y_eps)
                        ctx.stroke()
            # Draw the token label after its outgoing edges so the text
            # stays on top.
            ctx.set_source_rgb(0.0, 0.0, 0.0)
            ctx.move_to(x, y)
            ctx.show_text(s)
        y += layer_gap

    # Crop the recording to its ink extents (plus padding) and replay it
    # onto the final, correctly-sized PDF surface.
    x, y, width, height = surface.ink_extents()
    x -= padding
    y -= padding
    width += 2 * padding
    height += 2 * padding
    pdf_surface = cairo.PDFSurface(filename, width, height)
    ctx_pdf = cairo.Context(pdf_surface)
    ctx_pdf.set_source_surface(surface, -x, -y)
    ctx_pdf.paint()
    pdf_surface.finish()


######################################################################

if __name__ == "__main__":
    # Smoke test: run a tiny GPT, retrieve its recorded attention maps,
    # and render them (here replaced by random maps) to attention.pdf.
    import mygpt

    vocabulary_size = 3
    x = torch.randint(vocabulary_size, (1, 5))

    model = mygpt.MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=4,
        dim_keys=2,
        dim_hidden=2,
        nb_heads=2,
        nb_blocks=3,
        dropout=0.1,
        causal=True,
    )

    model.eval()
    model.record_attention()

    y1 = model(mygpt.BracketedSequence(x)).x

    a = model.retrieve_attention()
    print(a)
    # Keep the first (and only) batch entry of each layer, stacking them
    # into a (nb_layers, T, T) tensor. The original `layer_att[:0]`
    # sliced zero elements and produced an empty tensor; `layer_att`
    # also no longer shadows the input tensor `x` above.
    attention = torch.cat([layer_att[:1] for layer_att in a], dim=0)

    # Overridden below with random maps so the rendering can be checked
    # independently of the model.
    tokens = ["bluh", 2, 3, 4, "blih"]
    attention = torch.randn(3, len(tokens), len(tokens)).softmax(dim=-1)

    save_attention_image("attention.pdf", tokens, attention, padding=3)