a2554d2937e119af473d83d2c95e0b85e853a9ef
[picoclvr.git] / graph.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10 import cairo
11
12
13 ######################################################################
14
15
16 def save_attention_image(
17     filename,
18     tokens_input,
19     tokens_output,
20     # An iterable set of BxHxTxT attention matrices
21     attention_arrays,
22     n_sample=0,
23     n_head=0,
24     pixel_scale=8,
25     token_gap=15,
26     layer_gap=25,
27     y_eps=0.5,
28     padding=10,
29     # do not draw links with a lesser attention
30     min_link_attention=0,
31     # draw only the strongest links necessary to reache
32     # min_total_attention
33     min_total_attention=None,
34     # draw only the top k links
35     k_top=None,
36     curved=True,
37 ):
38     attention = torch.cat(
39         [x[n_sample : n_sample + 1, n_head] for x in attention_arrays], dim=0
40     )
41
42     if k_top is not None:
43         attention = attention * (
44             attention.sort(dim=-1, descending=True).indices < k_top
45         )
46
47     if min_total_attention is not None:
48         s = attention.sort(dim=-1)
49         m = 1 - (s.values.cumsum(-1) < 1 - min_total_attention).long()
50         b = m.new(attention.size()).scatter_(dim=-1, index=s.indices, src=m)
51         attention = attention * b
52
53     surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None)
54
55     ctx = cairo.Context(surface)
56     ctx.scale(pixel_scale, pixel_scale)
57
58     ctx.set_source_rgb(0.0, 0.0, 0.0)
59     ctx.set_font_size(4.0)
60     # ctx.select_font_face("Arial", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)
61
62     x, y = 0, 0
63
64     for d in range(attention.size(0)):
65         at = attention[d]
66         ni = torch.arange(at.size(0))[:, None].expand_as(at)
67         nj = torch.arange(at.size(1))[None, :].expand_as(at)
68         at = at.flatten()
69         o = at.sort().indices
70         at = at[o]
71         ni = ni.flatten()[o]
72         nj = nj.flatten()[o]
73         for i, j, a in zip(ni, nj, at):
74             if a > 0 and a >= min_link_attention:
75                 c = 1 - a.item()
76                 ctx.set_source_rgb(c, c, c)
77                 ctx.set_line_width(0.5)
78                 ax, ay = j * token_gap, y - y_eps
79                 ctx.move_to(ax, ay)
80                 dx, dy = i * token_gap, y - layer_gap + y_eps
81                 if curved:
82                     bx, by = ax, ay - layer_gap * 0.5
83                     cx, cy = dx, dy + layer_gap * 0.5
84                     ctx.curve_to(bx, by, cx, cy, dx, dy)
85                 else:
86                     ctx.line_to(dx, dy)
87                 ctx.stroke()
88         y -= layer_gap
89
90     for d in range(0, attention.size(0) + 1):
91         for n in range(attention.size(-1)):
92             xc, yc = n * token_gap, -d * layer_gap
93             ctx.set_source_rgb(1.0, 1.0, 1.0)
94             ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
95             ctx.fill()
96             ctx.set_source_rgb(0.0, 0.0, 0.0)
97             ctx.arc(xc, yc, token_gap / 20, 0, 2 * math.pi)
98             ctx.fill()
99
100     ctx.set_source_rgb(0.0, 0.0, 0.0)
101
102     for k, t in enumerate(tokens_input):
103         s = str(t)
104         (
105             x_bearing,
106             y_bearing,
107             width_t,
108             height_t,
109             x_advance,
110             y_advance,
111         ) = ctx.text_extents(s)
112         ctx.move_to(k * token_gap - width_t / 2, token_gap / 5 - y_bearing)
113         ctx.show_text(s)
114
115     for k, t in enumerate(tokens_output):
116         s = str(t)
117         (
118             x_bearing,
119             y_bearing,
120             width_t,
121             height_t,
122             x_advance,
123             y_advance,
124         ) = ctx.text_extents(s)
125         ctx.move_to(
126             k * token_gap - width_t / 2, -token_gap / 5 - attention.size(0) * layer_gap
127         )
128         ctx.show_text(s)
129
130     x, y, width, height = surface.ink_extents()
131     x -= padding
132     y -= padding
133     width += 2 * padding
134     height += 2 * padding
135     pdf_surface = cairo.PDFSurface(filename, width, height)
136     ctx_pdf = cairo.Context(pdf_surface)
137     ctx_pdf.set_source_surface(surface, -x, -y)
138     ctx_pdf.paint()
139     pdf_surface.finish()
140
141
142 ######################################################################
143
144 if __name__ == "__main__":
145     import mygpt
146
147     tokens_output = ["<wat>", 2, 3, 4, "<end>"]
148     tokens_input = [""] + tokens_output[:-1]
149
150     vocabulary_size = 3
151     x = torch.randint(vocabulary_size, (1, len(tokens_input)))
152
153     model = mygpt.MyGPT(
154         vocabulary_size=vocabulary_size,
155         dim_model=4,
156         dim_keys=2,
157         dim_hidden=2,
158         nb_heads=2,
159         nb_blocks=3,
160         dropout=0.1,
161         causal=True,
162     )
163
164     model.eval()
165     model.record_attention()
166
167     y1 = model(mygpt.BracketedSequence(x)).x
168
169     attention = model.retrieve_attention()
170
171     save_attention_image(
172         "attention.pdf",
173         tokens_input,
174         tokens_output,
175         attention,
176         # k_top=2,
177         min_total_attention=0.9,
178     )