Update.
[mygptrnn.git] / graph.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10 import cairo
11
12
13 ######################################################################
14
15
16 def save_attention_image(
17     # image to save
18     filename,
19     tokens_input,
20     tokens_output,
21     # list of 2d tensors T2xT1, T3xT2, ..., TkxTk-1
22     attention_matrices,
23     # do not draw links with a lesser attention
24     min_link_attention=0,
25     # draw only the strongest links necessary so that their summed
26     # attention is above min_total_attention
27     min_total_attention=None,
28     # draw only the top k links
29     k_top=None,
30     # the purely graphical settings
31     curved=True,
32     pixel_scale=8,
33     token_gap=15,
34     layer_gap=25,
35     y_eps=0.5,
36     padding=10,
37 ):
38     if k_top is not None:
39         am = []
40         for m in attention_matrices:
41             am.append(m * (m.sort(dim=-1, descending=True).indices < k_top))
42         attention_matrices = am
43
44     if min_total_attention is not None:
45         am = []
46         for m in attention_matrices:
47             s = m.sort(dim=-1)
48             m = 1 - (s.values.cumsum(-1) < 1 - min_total_attention).long()
49             b = m.new(m.size()).scatter_(dim=-1, index=s.indices, src=m)
50             am.append(m * b)
51
52     surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None)
53
54     ctx = cairo.Context(surface)
55     ctx.scale(pixel_scale, pixel_scale)
56
57     ctx.set_source_rgb(0.0, 0.0, 0.0)
58     ctx.set_font_size(4.0)
59     # ctx.select_font_face("Arial", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)
60
61     x, y = 0, 0
62
63     ctx.set_line_width(0.25)
64     for d in range(len(attention_matrices)):
65         at = attention_matrices[d].to("cpu")
66         ni = torch.arange(at.size(0))[:, None].expand_as(at)
67         nj = torch.arange(at.size(1))[None, :].expand_as(at)
68         at = at.flatten()
69         o = at.sort().indices
70         at = at[o]
71         ni = ni.flatten()[o]
72         nj = nj.flatten()[o]
73         for i, j, a in zip(ni, nj, at):
74             if a > 0 and a >= min_link_attention:
75                 c = 1 - a.item()
76                 ctx.set_source_rgb(c, c, c)
77                 ax, ay = j * token_gap, y - y_eps
78                 ctx.move_to(ax, ay)
79                 dx, dy = i * token_gap, y - layer_gap + y_eps
80                 if curved:
81                     bx, by = ax, ay - layer_gap * 0.5
82                     cx, cy = dx, dy + layer_gap * 0.5
83                     ctx.curve_to(bx, by, cx, cy, dx, dy)
84                 else:
85                     ctx.line_to(dx, dy)
86                 ctx.stroke()
87         y -= layer_gap
88
89     for d in range(0, len(attention_matrices) + 1):
90         n = (
91             attention_matrices[0].size(-1)
92             if d == 0
93             else attention_matrices[d - 1].size(-2)
94         )
95         for n in range(n):
96             xc, yc = n * token_gap, -d * layer_gap
97             ctx.set_source_rgb(1.0, 1.0, 1.0)
98             ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
99             ctx.fill()
100             ctx.set_source_rgb(0.0, 0.0, 0.0)
101             ctx.arc(xc, yc, token_gap / 20, 0, 2 * math.pi)
102             ctx.fill()
103
104     ctx.set_source_rgb(0.0, 0.0, 0.0)
105
106     for k, t in enumerate(tokens_input):
107         s = str(t)
108         (
109             x_bearing,
110             y_bearing,
111             width_t,
112             height_t,
113             x_advance,
114             y_advance,
115         ) = ctx.text_extents(s)
116         ctx.move_to(k * token_gap - width_t / 2, 2 * token_gap / 5)
117         ctx.show_text(s)
118
119     for k, t in enumerate(tokens_output):
120         s = str(t)
121         (
122             x_bearing,
123             y_bearing,
124             width_t,
125             height_t,
126             x_advance,
127             y_advance,
128         ) = ctx.text_extents(s)
129         ctx.move_to(
130             k * token_gap - width_t / 2,
131             -token_gap / 5 - len(attention_matrices) * layer_gap,
132         )
133         ctx.show_text(s)
134
135     x, y, width, height = surface.ink_extents()
136     x -= padding
137     y -= padding
138     width += 2 * padding
139     height += 2 * padding
140     pdf_surface = cairo.PDFSurface(filename, width, height)
141     ctx_pdf = cairo.Context(pdf_surface)
142     ctx_pdf.set_source_surface(surface, -x, -y)
143     ctx_pdf.paint()
144     pdf_surface.finish()
145
146
147 ######################################################################
148
149 if __name__ == "__main__":
150     import mygpt
151
152     tokens_output = ["<wat>", "-", 3, 4, "<end>"]
153     tokens_input = [""] + tokens_output[:-1]
154
155     vocabulary_size = 3
156     x = torch.randint(vocabulary_size, (1, len(tokens_input)))
157
158     model = mygpt.MyGPT(
159         vocabulary_size=vocabulary_size,
160         dim_model=4,
161         dim_keys=2,
162         dim_hidden=2,
163         nb_heads=2,
164         nb_blocks=5,
165         dropout=0.1,
166         causal=True,
167     )
168
169     model.eval()
170     model.record_attention()
171
172     y1 = model(mygpt.BracketedSequence(x)).x
173
174     attention_matrices = [m[0, 0] for m in model.retrieve_attention()]
175
176     # attention_matrices = [torch.rand(*s) for s in [ (4,5),(3,4),(8,3),(5,8) ]]
177
178     save_attention_image(
179         "attention.pdf",
180         tokens_input,
181         tokens_output,
182         attention_matrices,
183         # k_top=2,
184         min_total_attention=0.9,
185     )