X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=graph.py;h=07e376a3af86d24874523775e26f16132a489512;hb=HEAD;hp=97de6d108daf435553fe8cb22c3c427a1af391fe;hpb=00b2d5ed01fb523fbc4e699f0419329efbee0ea8;p=picoclvr.git

diff --git a/graph.py b/graph.py
index 97de6d1..07e376a 100755
--- a/graph.py
+++ b/graph.py
@@ -11,23 +11,48 @@ import cairo
 ######################################################################
 
+
+
 def save_attention_image(
+    # file name of the image to save
     filename,
-    tokens,
-    attention,
-    surface_width=128,
-    surface_height=96,
+    tokens_input,
+    tokens_output,
+    # list of 2d tensors T2xT1, T3xT2, ..., TkxTk-1
+    attention_matrices,
+    # do not draw links whose attention is below this value
+    min_link_attention=0,
+    # draw only the strongest links necessary for their summed
+    # attention to reach min_total_attention
+    min_total_attention=None,
+    # draw only the k strongest links out of each token
+    k_top=None,
+    # purely graphical settings
+    curved=True,
     pixel_scale=8,
-    x=10,
-    y=10,
     token_gap=15,
     layer_gap=25,
-    y_eps=1,
-    min_att=1e-2,
+    y_eps=0.5,
+    padding=10,
 ):
-    # surface = cairo.PDFSurface(
-    #     filename, surface_width * pixel_scale, surface_height * pixel_scale
-    # )
+    if k_top is not None:
+        am = []
+        for m in attention_matrices:
+            # rank of each link in its row, 0 being the strongest
+            ranks = m.sort(dim=-1, descending=True).indices.argsort(dim=-1)
+            am.append(m * (ranks < k_top))
+        attention_matrices = am
+
+    if min_total_attention is not None:
+        am = []
+        for m in attention_matrices:
+            s = m.sort(dim=-1)
+            # mask, in ascending order, of the strongest links whose
+            # summed attention reaches min_total_attention
+            v = 1 - (s.values.cumsum(-1) < 1 - min_total_attention).long()
+            b = v.new(v.size()).scatter_(dim=-1, index=s.indices, src=v)
+            am.append(m * b)
+        attention_matrices = am
 
     surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None)
 
     ctx = cairo.Context(surface)
@@ -38,9 +63,55 @@ def save_attention_image(
     ctx.set_font_size(4.0)
     # ctx.select_font_face("Arial", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)
 
-    u = []
-    for n, t in enumerate(tokens):
-        string = str(t)
+    x, y = 0, 0
+
+    # draw the links, weakest first, so that the strongest end up on top
+    ctx.set_line_width(0.25)
+    for d in range(len(attention_matrices)):
+        at = attention_matrices[d].to("cpu")
+        ni = torch.arange(at.size(0))[:, None].expand_as(at)
+        nj = torch.arange(at.size(1))[None, :].expand_as(at)
+        at = at.flatten()
+        o = at.sort().indices
+        at = at[o]
+        ni = ni.flatten()[o]
+        nj = nj.flatten()[o]
+        for i, j, a in zip(ni, nj, at):
+            if a > 0 and a >= min_link_attention:
+                c = 1 - a.item()
+                ctx.set_source_rgb(c, c, c)
+                ax, ay = j * token_gap, y - y_eps
+                ctx.move_to(ax, ay)
+                dx, dy = i * token_gap, y - layer_gap + y_eps
+                if curved:
+                    bx, by = ax, ay - layer_gap * 0.5
+                    cx, cy = dx, dy + layer_gap * 0.5
+                    ctx.curve_to(bx, by, cx, cy, dx, dy)
+                else:
+                    ctx.line_to(dx, dy)
+                ctx.stroke()
+        y -= layer_gap
+
+    # draw a white-rimmed black dot at every token position
+    for d in range(0, len(attention_matrices) + 1):
+        nb = (
+            attention_matrices[0].size(-1)
+            if d == 0
+            else attention_matrices[d - 1].size(-2)
+        )
+        for n in range(nb):
+            xc, yc = n * token_gap, -d * layer_gap
+            ctx.set_source_rgb(1.0, 1.0, 1.0)
+            ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
+            ctx.fill()
+            ctx.set_source_rgb(0.0, 0.0, 0.0)
+            ctx.arc(xc, yc, token_gap / 20, 0, 2 * math.pi)
+            ctx.fill()
+
+    ctx.set_source_rgb(0.0, 0.0, 0.0)
+
+    for k, t in enumerate(tokens_input):
+        s = str(t)
         (
             x_bearing,
             y_bearing,
@@ -48,32 +119,31 @@ def save_attention_image(
             width_t,
             height_t,
             x_advance,
             y_advance,
-        ) = ctx.text_extents(string)
-        u.append((n, string, x, x + width_t / 2, height_t, y_bearing))
-        x += x_advance + token_gap
-    tokens = u
-
-    for d in range(attention.size(0) + 1):
-        for n, s, x, xc, h, yb in tokens:
-            # ctx.set_source_rgb(0.0, 0.0, 0.0)
-            # ctx.rectangle(x+x_bearing,y+y_bearing,width_t,height_t)
-            # ctx.stroke()
-            ctx.set_source_rgb(0.0, 0.0, 0.0)
-            ctx.move_to(x, y)
-            ctx.show_text(s)
-            # x += x_advance + 1
-            if d < attention.size(0):
-                for m, _, _, x2c, h2, y2b in tokens:
-                    if attention[d, n, m] >= min_att:
-                        c = 1 - attention[d, n, m]
-                        ctx.set_source_rgb(c, c, c)
-                        ctx.set_line_width(0.5)
-                        ctx.move_to(xc, y + yb + h + y_eps)
-                        ctx.line_to(x2c, y + layer_gap + y2b - y_eps)
-                        ctx.stroke()
-        y += layer_gap
+        ) = ctx.text_extents(s)
+        ctx.move_to(k * token_gap - width_t / 2, 2 * token_gap / 5)
+        ctx.show_text(s)
+
+    for k, t in enumerate(tokens_output):
+        s = str(t)
+        (
+            x_bearing,
+            y_bearing,
+            width_t,
+            height_t,
+            x_advance,
+            y_advance,
+        ) = ctx.text_extents(s)
+        ctx.move_to(
+            k * token_gap - width_t / 2,
+            -token_gap / 5 - len(attention_matrices) * layer_gap,
+        )
+        ctx.show_text(s)
 
     x, y, width, height = surface.ink_extents()
+    x -= padding
+    y -= padding
+    width += 2 * padding
+    height += 2 * padding
     pdf_surface = cairo.PDFSurface(filename, width, height)
     ctx_pdf = cairo.Context(pdf_surface)
     ctx_pdf.set_source_surface(surface, -x, -y)
@@ -86,8 +156,11 @@
 if __name__ == "__main__":
     import mygpt
 
+    tokens_output = ["", "-", 3, 4, ""]
+    tokens_input = [""] + tokens_output[:-1]
+
     vocabulary_size = 3
-    x = torch.randint(vocabulary_size, (1, 5))
+    x = torch.randint(vocabulary_size, (1, len(tokens_input)))
 
     model = mygpt.MyGPT(
         vocabulary_size=vocabulary_size,
@@ -95,7 +168,7 @@ if __name__ == "__main__":
         dim_keys=2,
         dim_hidden=2,
         nb_heads=2,
-        nb_blocks=3,
+        nb_blocks=5,
         dropout=0.1,
         causal=True,
     )
@@ -105,11 +178,15 @@ if __name__ == "__main__":
 
     y1 = model(mygpt.BracketedSequence(x)).x
 
-    a = model.retrieve_attention()
-    print(a)
-    attention = torch.cat([x[:0] for x in a], dim=0)
+    attention_matrices = [m[0, 0] for m in model.retrieve_attention()]
 
-    tokens = ["bluh", 2, 3, 4, "blih"]
-    attention = torch.randn(3, len(tokens), len(tokens)).softmax(dim=-1)
+    # attention_matrices = [torch.rand(*s) for s in [(4, 5), (3, 4), (8, 3), (5, 8)]]
 
-    save_attention_image("attention.pdf", tokens, attention)
+    save_attention_image(
+        "attention.pdf",
+        tokens_input,
+        tokens_output,
+        attention_matrices,
+        # k_top=2,
+        min_total_attention=0.9,
+    )
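
As a quick illustration of the two link filters added above -- this snippet is not part of the patch, just a sketch of the same logic applied to a single made-up attention row:

    import torch

    # one made-up attention row, summing to 1.0 like a softmax output
    m = torch.tensor([[0.05, 0.50, 0.10, 0.35]])

    # min_total_attention=0.8: drop the weakest links as long as the dropped
    # mass stays below 1 - 0.8, so that the kept links sum to at least 0.8
    s = m.sort(dim=-1)  # ascending
    v = 1 - (s.values.cumsum(-1) < 1 - 0.8).long()
    b = v.new(v.size()).scatter_(dim=-1, index=s.indices, src=v)
    print(m * b)  # tensor([[0.0000, 0.5000, 0.0000, 0.3500]])

    # k_top=2: rank each link within its row (0 = strongest), keep ranks < 2
    ranks = m.sort(dim=-1, descending=True).indices.argsort(dim=-1)
    print(m * (ranks < 2))  # the same two links survive

With min_total_attention=0.9, as in the __main__ test above, the same row would keep three links, since 0.50 + 0.35 = 0.85 falls short of the threshold.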