1 % -*- mode: latex; mode: reftex; mode: auto-fill; mode: flyspell; mode: yas/minor; coding: utf-8; tex-command: "pdflatex.sh" -*-
3 % Any copyright is dedicated to the Public Domain.
4 % https://creativecommons.org/publicdomain/zero/1.0/
6 % Written by Francois Fleuret <francois@fleuret.org>
\documentclass[c,8pt]{beamer}

% Hide beamer's navigation symbols on every slide.
\setbeamertemplate{navigation symbols}{}

% Notation macros. \newcommand* / \DeclareMathOperator raise an error if the
% name is already taken, instead of silently overwriting it like \def does.
% \transpose is meant to be used as a postfix superscript: Q K\transpose.
\newcommand*{\transpose}{^{\top}}
% Upright "softmax" with operator spacing (same output as \operatorname{softmax}).
\DeclareMathOperator{\softmax}{softmax}

% Palette. NOTE: this intentionally overrides xcolor's predefined `blue' and
% `green', so every later use of those names (e.g. fill=blue!15) is affected.
\definecolor{blue}{rgb}{0.0,0.0,0.55}
\definecolor{green}{rgb}{0.0,0.50,0.0}
\definecolor{bluegray}{rgb}{0.1,0.2,0.7}

% Beamer colour theme: maths in blue-gray, local structure in the dark blue above.
\setbeamercolor{math text}{fg=bluegray}
\setbeamercolor{local structure}{fg=blue}
24 \usetikzlibrary{positioning,fit,backgrounds}
25 \usetikzlibrary{arrows.meta,decorations.pathreplacing}
27 \usetikzlibrary{shapes,calc,intersections}
28 \usetikzlibrary{patterns}
30 \usetikzlibrary{arrows}
% Named fill colours nn-data / nn-param / nn-process (not referenced in the
% slides visible here; presumably used by other figures in this file).
\definecolor{nn-data}{rgb}{0.90, 0.95, 1.00}
\definecolor{nn-param}{rgb}{1.00, 0.90, 0.50}
\definecolor{nn-process}{rgb}{0.80, 1.00, 0.80}

% Default tip drawn by `>' in arrow specs such as `->' (used by the wire styles).
\tikzset{>={Straight Barb[angle'=80,scale=1.1]}}
% ---- Box styles ------------------------------------------------------
% All boxes: thick gray border, 10pt minimum size, 20pt minimum height.
% Options given at the use site (e.g. minimum height=0.8cm) come after the
% style and therefore override these defaults.
%
% Value (tensor) box: white fill.
value/.style={
  font=\scriptsize, rectangle, draw=black!50, fill=white, thick,
  inner sep=3pt, inner xsep=2pt, minimum size=10pt, minimum height=20pt
},
% Parameter box: light blue fill, no vertical inner padding.
parameter/.style={
  font=\scriptsize, rectangle, draw=black!50, fill=blue!15, thick,
  inner sep=0pt, inner xsep=2pt, minimum size=10pt, minimum height=20pt
},
% Operation box: light green fill.
operation/.style={
  font=\scriptsize, rectangle, draw=black!50, fill=green!30, thick,
  inner sep=3pt, minimum size=10pt, minimum height=20pt
},
% Generic arrow, slightly shortened at both ends.
flow/.style={
  ->, shorten <= 1pt, shorten >= 1pt, draw=black!50, thick
},
% ---- Wire styles: <source>2<target>, v = value box, f = operation --------
% f -> f: plain line.
f2f/.style={
  draw=black!50, thick
},
% v -> f: bar tip at the value end, nothing at the operation end.
v2f/.style={
  {Bar[width=1.5mm]}-, shorten <= 0.75pt, draw=black!50, thick
},
% f -> v: arrowhead into the value box.
f2v/.style={
  ->, shorten >= 0.75pt, draw=black!50, thick
},
% v -> v: bar at the start, arrowhead at the end.
v2v/.style={
  {Bar[width=1.5mm]}->, shorten <= 0.75pt, shorten >= 0.5pt, draw=black!50, thick
},
% ---- Same wires, drawn in solid black instead of black!50 ------------
df2f/.style={
  draw=black, thick
},
dv2f/.style={
  {Bar[width=1.5mm]}-, shorten <= 0.75pt, draw=black, thick
},
df2v/.style={
  ->, shorten >= 0.75pt, draw=black, thick
},
dv2v/.style={
  {Bar[width=1.5mm]}->, shorten <= 0.75pt, shorten >= 0.5pt, draw=black, thick
},
% Yellow box with \small text (larger than the \scriptsize boxes above).
differential/.style={
  font=\small, rectangle, draw=black!50, thick,
  inner sep=3pt, inner xsep=2pt, minimum size=10pt, minimum height=20pt, fill=yellow!80
},
% Like `flow', but in solid black.
dflow/.style={
  ->, shorten <= 1pt, shorten >= 1pt, draw=black, thick
}
62 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Input matrices, stacked vertically around K; box sizes encode their shapes.
\node[value, minimum height=0.8cm, minimum width=0.7cm] (K) at (0, 0) {$K$};
\node[value, minimum height=1.2cm, minimum width=0.7cm] (Q) [above=0.5cm of K] {$Q$};
\node[value, minimum height=0.8cm, minimum width=1.0cm] (V) [below=0.5cm of K] {$V$};

% Left-to-right chain: (Q, K) -> transpose-product -> softmax -> A, then (A, V) -> product -> Y.
\node[operation, minimum height=0.4cm, minimum width=0.4cm] (att) [right=0.5cm of K] {$\cdot\transpose$};
\node[operation, minimum height=0.4cm, minimum width=0.4cm] (sm) [right=0.25cm of att] {$\softmax$};
\node[value, minimum height=1.2cm, minimum width=0.8cm] (A) [right=0.5cm of sm] {$A$};
\node[operation, minimum height=0.4cm, minimum width=0.4cm] (prod) [right=0.5cm of A] {$\cdot$};
\node[value, minimum height=1.2cm, minimum width=1.0cm] (Y) [right=0.5cm of prod] {$Y$};

% Wires; every one uses the same 1mm corner rounding.
\begin{scope}[rounded corners=1mm]
  \draw[v2f] (K) -- (att);
  \draw[v2f] (Q) -| (att);
  \draw[f2f] (att) -- (sm);
  \draw[f2v] (sm) -- ([xshift=-1pt]A.west);
  \draw[v2f] (A) -- (prod);
  \draw[v2f] (V) -| (prod);
  \draw[f2v] (prod) -- ([xshift=-1pt]Y.west);
\end{scope}

% Dimension bars: a bar above a box spans its width, a bar on its left spans
% its height; bars of the same colour sit on edges of equal length
% (e.g. Q, A and Y all have cyan left bars and the same 1.2cm height).
\foreach \bnode/\bcolor in {Q/yellow,K/yellow,V/orange,Y/orange}
  \draw[very thick, color=\bcolor] ([yshift=1pt]\bnode.north west) -- ([yshift=1pt]\bnode.north east);
\foreach \bnode/\bcolor in {V/red,K/red,Q/cyan,Y/cyan,A/cyan}
  \draw[very thick, color=\bcolor] ([xshift=-1pt]\bnode.north west) -- ([xshift=-1pt]\bnode.south west);
\draw[very thick, red] ([yshift=1pt]A.north west) -- ([yshift=1pt]A.north east);
% Row-wise softmax of Q K^T scaled by 1/sqrt(D); "row" is a label, so it is set upright.
A & = \softmax_{\mathrm{row}} \left( \frac{Q K\transpose}{\sqrt{D}} \right) \\
Single-head attention
116 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Q, K and V stacked around K, with wider gaps than in a bare attention figure
% to leave room for the product nodes on their left.
\node[value, minimum height=0.8cm, minimum width=0.7cm] (K) at (0, 0) {$K$};
\node[value, minimum height=1.2cm, minimum width=0.7cm] (Q) [above=1cm of K] {$Q$};
\node[value, minimum height=0.8cm, minimum width=1.0cm] (V) [below=1cm of K] {$V$};

% One product node 1cm to the left of each of Q, K and V. They are given the
% same 0.4cm x 0.4cm footprint as the `prod' node below, which shows the same
% $\cdot$ symbol (previously their width fell back to the style minimum).
\node[operation, minimum height=0.4cm, minimum width=0.4cm] (mulWq) [left=1cm of Q.center] {$\cdot$};
\node[operation, minimum height=0.4cm, minimum width=0.4cm] (mulWk) [left=1cm of K.center] {$\cdot$};
\node[operation, minimum height=0.4cm, minimum width=0.4cm] (mulWv) [left=1cm of V.center] {$\cdot$};

% Input X (level with mulWq), the parameter W^Q above it, and X' below it.
\node[value, minimum height=1.2cm, minimum width=0.5cm] (X) [left=1cm of mulWq] {$X$};
\node[parameter, minimum height=0.5cm, minimum width=0.7cm] (Wq) [above=0.25cm of X] {$W^Q$};
\node[value, minimum height=0.8cm, minimum width=0.3cm] (X') [below=1.2cm of X] {$X'$};

% Attention chain to the right of K.
\node[operation, minimum height=0.4cm, minimum width=0.4cm] (att) [right=0.5cm of K] {$\cdot\transpose$};
\node[operation, minimum height=0.4cm, minimum width=0.4cm] (sm) [right=0.25cm of att] {$\softmax$};
\node[value, minimum height=1.2cm, minimum width=0.8cm] (A) [right=0.5cm of sm] {$A$};
\node[operation, minimum height=0.4cm, minimum width=0.4cm] (prod) [right=0.5cm of A] {$\cdot$};
\node[value, minimum height=1.2cm, minimum width=1.0cm] (Y) [right=0.5cm of prod] {$Y$};

% Dimension bars (top = width, left = height). Bars of the same colour mark
% edges of equal length, e.g. X's width and W^Q's height are both green.
\foreach \bnode/\bcolor in {K/yellow,Q/yellow,V/orange,X/green,Wq/yellow,A/red,Y/orange}
  \draw[very thick, color=\bcolor] ([yshift=1pt]\bnode.north west) -- ([yshift=1pt]\bnode.north east);
\foreach \bnode/\bcolor in {K/red,Q/cyan,V/red,X/cyan,Wq/green,A/cyan,Y/cyan}
  \draw[very thick, color=\bcolor] ([xshift=-1pt]\bnode.north west) -- ([xshift=-1pt]\bnode.south west);
% Wires; every one uses the same 1mm corner rounding.
\begin{scope}[rounded corners=1mm]
  % Query projection: X and W^Q feed the product node, whose output is Q.
  \draw[v2f] (X) -- (mulWq);
  \draw[v2f] (Wq) -| (mulWq);
  \draw[f2v] (mulWq) -- ([xshift=-1pt]Q.west);

  % A from K and Q through the transpose-product and softmax nodes.
  \draw[v2f] (K) -- (att);
  \draw[v2f] (Q) -| (att);
  \draw[f2f] (att) -- (sm);
  \draw[f2v] (sm) -- ([xshift=-1pt]A.west);

  % Y from A and V through the product node.
  \draw[v2f] (A) -- (prod);
  \draw[v2f] (V) -| (prod);
  \draw[f2v] (prod) -- ([xshift=-1pt]Y.west);
\end{scope}