08f1170b9dfc4838340e972d8731ace834967246
[picoclvr.git] / graph.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10 import cairo
11
12
13 ######################################################################
14
15
16 def save_attention_image(
17     filename,  # image to save
18     tokens_input,
19     tokens_output,
20     attention_matrices,  # list of 2d tensors T1xT2, T2xT3, ..., Tk-1xTk
21     # do not draw links with a lesser attention
22     min_link_attention=0,
23     # draw only the strongest links necessary to reache
24     # min_total_attention
25     min_total_attention=None,
26     # draw only the top k links
27     k_top=None,
28     curved=True,
29     pixel_scale=8,
30     token_gap=15,
31     layer_gap=25,
32     y_eps=0.5,
33     padding=10,
34 ):
35     if k_top is not None:
36         am = []
37         for m in attention_matrices:
38             am.append(m * (m.sort(dim=-1, descending=True).indices < k_top))
39         attention_matrices = am
40
41     if min_total_attention is not None:
42         am = []
43         for m in attention_matrices:
44             s = m.sort(dim=-1)
45             m = 1 - (s.values.cumsum(-1) < 1 - min_total_attention).long()
46             b = m.new(m.size()).scatter_(dim=-1, index=s.indices, src=m)
47             am.append(m * b)
48
49     surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None)
50
51     ctx = cairo.Context(surface)
52     ctx.scale(pixel_scale, pixel_scale)
53
54     ctx.set_source_rgb(0.0, 0.0, 0.0)
55     ctx.set_font_size(4.0)
56     # ctx.select_font_face("Arial", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)
57
58     x, y = 0, 0
59
60     ctx.set_line_width(0.25)
61     for d in range(len(attention_matrices)):
62         at = attention_matrices[d]
63         ni = torch.arange(at.size(0))[:, None].expand_as(at)
64         nj = torch.arange(at.size(1))[None, :].expand_as(at)
65         at = at.flatten()
66         o = at.sort().indices
67         at = at[o]
68         ni = ni.flatten()[o]
69         nj = nj.flatten()[o]
70         for i, j, a in zip(ni, nj, at):
71             if a > 0 and a >= min_link_attention:
72                 c = 1 - a.item()
73                 ctx.set_source_rgb(c, c, c)
74                 ax, ay = j * token_gap, y - y_eps
75                 ctx.move_to(ax, ay)
76                 dx, dy = i * token_gap, y - layer_gap + y_eps
77                 if curved:
78                     bx, by = ax, ay - layer_gap * 0.5
79                     cx, cy = dx, dy + layer_gap * 0.5
80                     ctx.curve_to(bx, by, cx, cy, dx, dy)
81                 else:
82                     ctx.line_to(dx, dy)
83                 ctx.stroke()
84         y -= layer_gap
85
86     for d in range(0, len(attention_matrices) + 1):
87         n = (
88             attention_matrices[0].size(-1)
89             if d == 0
90             else attention_matrices[d - 1].size(-2)
91         )
92         for n in range(n):
93             xc, yc = n * token_gap, -d * layer_gap
94             ctx.set_source_rgb(1.0, 1.0, 1.0)
95             ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
96             ctx.fill()
97             ctx.set_source_rgb(0.0, 0.0, 0.0)
98             ctx.arc(xc, yc, token_gap / 20, 0, 2 * math.pi)
99             ctx.fill()
100
101     ctx.set_source_rgb(0.0, 0.0, 0.0)
102
103     for k, t in enumerate(tokens_input):
104         s = str(t)
105         (
106             x_bearing,
107             y_bearing,
108             width_t,
109             height_t,
110             x_advance,
111             y_advance,
112         ) = ctx.text_extents(s)
113         ctx.move_to(k * token_gap - width_t / 2, token_gap / 5 - y_bearing)
114         ctx.show_text(s)
115
116     for k, t in enumerate(tokens_output):
117         s = str(t)
118         (
119             x_bearing,
120             y_bearing,
121             width_t,
122             height_t,
123             x_advance,
124             y_advance,
125         ) = ctx.text_extents(s)
126         ctx.move_to(
127             k * token_gap - width_t / 2,
128             -token_gap / 5 - len(attention_matrices) * layer_gap,
129         )
130         ctx.show_text(s)
131
132     x, y, width, height = surface.ink_extents()
133     x -= padding
134     y -= padding
135     width += 2 * padding
136     height += 2 * padding
137     pdf_surface = cairo.PDFSurface(filename, width, height)
138     ctx_pdf = cairo.Context(pdf_surface)
139     ctx_pdf.set_source_surface(surface, -x, -y)
140     ctx_pdf.paint()
141     pdf_surface.finish()
142
143
144 ######################################################################
145
146 if __name__ == "__main__":
147     import mygpt
148
149     tokens_output = ["<wat>", 2, 3, 4, "<end>"]
150     tokens_input = [""] + tokens_output[:-1]
151
152     vocabulary_size = 3
153     x = torch.randint(vocabulary_size, (1, len(tokens_input)))
154
155     model = mygpt.MyGPT(
156         vocabulary_size=vocabulary_size,
157         dim_model=4,
158         dim_keys=2,
159         dim_hidden=2,
160         nb_heads=2,
161         nb_blocks=5,
162         dropout=0.1,
163         causal=True,
164     )
165
166     model.eval()
167     model.record_attention()
168
169     y1 = model(mygpt.BracketedSequence(x)).x
170
171     attention_matrices = [m[0, 0] for m in model.retrieve_attention()]
172
173     # attention_matrices = [ torch.rand(3,5), torch.rand(8,3), torch.rand(5,8) ]
174     # for a in attention_matrices: a=a/a.sum(-1,keepdim=True)
175
176     save_attention_image(
177         "attention.pdf",
178         tokens_input,
179         tokens_output,
180         attention_matrices,
181         # k_top=2,
182         min_total_attention=0.9,
183     )