Initial commit
[tex.git] / attention.tex
1 % -*- mode: latex; mode: reftex; mode: auto-fill; mode: flyspell; -*-
2
3 \documentclass[c,8pt]{beamer}
4
5 \usepackage{tikz}
6 \newcommand{\transpose}{^{\top}}
7 \def\softmax{\operatorname{softmax}}
8
9 \setbeamertemplate{navigation symbols}{}
10
11 \begin{document}
12
13 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
14
15 \begin{frame}[fragile]
16
17 Given a query sequence $Q$, a key sequence $K$, and a value sequence
18 $V$, compute an attention matrix $A$ by matching $Q$s to $K$s, and
19 weight $V$ with it to get $Y$.
20
21 \medskip
22
23 \[
24 \uncover<2,4,6->{
25   A_i = \softmax \left( \frac{Q_i \, K\transpose}{\sqrt{d}} \right)
26 }
27 %
28 \quad \quad \quad
29 %
30 \uncover<3,5->{
31   Y_i = A_i V
32 }
33 \]
34
35 \medskip
36
37 \makebox[\textwidth][c]{
38 \begin{tikzpicture}
39
40   \node[cm={0.5, 0.5, 0.0, 1.0, (0.0, 0.0)}] (V) at (-2, 2.35) {
41     \begin{tikzpicture}
42       \draw[fill=green!20] (0, 0) rectangle (4, 1.4);
43       \uncover<3,5>{\draw[fill=yellow] (0, 0) rectangle (4, 1.4);}
44       \foreach \x in { 0.2, 0.4, ..., 3.8 } \draw (\x, 0) -- ++(0, 1.4);
45     \end{tikzpicture}
46   };
47
48   \node[cm={1.0, 0.0, 0.5, 0.5, (0.0, 0.0)}] (A) at (0.5, 1.6) {
49     \begin{tikzpicture}
50       \draw (0, 0) rectangle ++(3, 4);
51     \end{tikzpicture}
52   };
53
54   \uncover<2-3>{
55   \node[cm={0.5, 0.5, 0.0, 1.0, (0.0, 0.0)}] (a1) at (-0.9, 2.1) {
56     \begin{tikzpicture}
57       \draw[draw=none] (0, 0) rectangle (4, 1);
58       \foreach \x/\y in {
59         0.00/0.03, 0.20/0.04, 0.40/0.07, 0.60/0.35, 0.80/0.52,
60         1.00/1.00, 1.20/0.82, 1.40/0.25, 1.60/0.08, 1.80/0.03,
61         2.00/0.15, 2.20/0.24, 2.40/0.70, 2.60/0.05, 2.80/0.03,
62         3.00/0.03, 3.20/0.03, 3.40/0.00, 3.60/0.03, 3.80/0.00 }{
63         \uncover<2>{\draw[black,fill=red] (\x, 0) rectangle ++(0.2, \y);}
64         \uncover<3>{\draw[black,fill=yellow] (\x, 0) rectangle ++(0.2, \y);}
65       };
66     \end{tikzpicture}
67   };
68   }
69
70   \uncover<4-5>{
71   \node[cm={0.5, 0.5, 0.0, 1.0, (0.0, 0.0)}] (a2) at (-0.7, 2.1) {
72     \begin{tikzpicture}
73       \draw[draw=none] (0, 0) rectangle (4, 1);
74       \foreach \x/\y in {
75         0.00/0.03, 0.20/0.04, 0.40/0.07, 0.60/0.03, 0.80/0.03,
76         1.00/0.05, 1.20/0.02, 1.40/0.08, 1.60/0.35, 1.80/0.85,
77         2.00/0.05, 2.20/0.04, 2.40/0.03, 2.60/0.05, 2.80/0.03,
78         3.00/0.03, 3.20/0.03, 3.40/0.00, 3.60/0.03, 3.80/0.00 }{
79         \uncover<4>{\draw[black,fill=red] (\x, 0) rectangle ++(0.2, \y);}
80         \uncover<5>{\draw[black,fill=yellow] (\x, 0) rectangle ++(0.2, \y);}
81       };
82     \end{tikzpicture}
83   };
84   }
85
86   \node[cm={1.0, 0.0, 0.0, 1.0, (0.0, 0.0)}] (Q) at (-0.5, -0.05) {
87     \begin{tikzpicture}
88       \draw[fill=green!20] (0, 0) rectangle (3, 1.0);
89       \foreach \x in { 0.2, 0.4, ..., 2.8 } \draw (\x, 0) -- ++(0, 1.0);
90       \uncover<2>{\draw[fill=yellow] (0.0, 0) rectangle ++(0.2, 1);}
91       \uncover<4>{\draw[fill=yellow] (0.2, 0) rectangle ++(0.2, 1);}
92     \end{tikzpicture}
93     };
94
95   \node[cm={1.0, 0.0, 0.0, 1.0, (0.0, 0.0)}] (Y) at (1.5, 3.45) {
96     \begin{tikzpicture}
97       \uncover<3>{\draw[fill=red] (0.0, 0) rectangle ++(0.2, 1.4);}
98       \uncover<4->{\draw[fill=green!20] (0.0, 0) rectangle ++(0.2, 1.4);}
99       \uncover<6->{\draw[fill=green!20] (0.0, 0) rectangle ++(3, 1.4);}
100       \uncover<5>{\draw[fill=red] (0.2, 0) rectangle ++(0.2, 1.4);}
101       \draw (0, 0) rectangle (3, 1.4);
102       \foreach \x in { 0.2, 0.4, ..., 2.8 } \draw (\x, 0) -- ++(0, 1.4);
103     \end{tikzpicture}
104     };
105
106   \node[cm={0.5, 0.5, 0.0, 1.0, (0.0, 0.0)}] (K) at (3, 1.1) {
107     \begin{tikzpicture}
108       \draw[fill=green!20] (0, 0) rectangle (4, 1);
109       \uncover<2,4>{\draw[fill=yellow] (0, 0) rectangle (4, 1);}
110       \foreach \x in { 0.2, 0.4, ..., 3.8 } \draw (\x, 0) -- ++(0, 1);
111     \end{tikzpicture}
112   };
113
114   \node[left of=V,xshift=0.5cm,yshift=0.7cm] (Vl) {V};
115   \node[left of=Q,xshift=-0.8cm] (Ql) {Q};
116   \node (Al) at (A) {A};
117   \node[right of=K,xshift=-0.6cm,yshift=-0.6cm] (Kl) {K};
118   \node[right of=Y,xshift=0.8cm] (Yl) {Y};
119
120 \end{tikzpicture}
121 }
122
123 \end{frame}
124
125 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
126
127 \begin{frame}[fragile]
128
129 A standard attention layer takes as input two sequences $X$ and $X'$
130 and computes
131 %
132 \begin{align*}
133 K & = W^K X \\
134 V & = W^V X \\
135 Q & = w^Q X' \\
136 Y & = \underbrace{\softmax_{row} \left( \frac{Q K\transpose}{\sqrt{d}} \right)}_{A} V
137 \end{align*}
138
139 When $X = X'$, this is \textbf{self attention}, otherwise \textbf{cross
140   attention.}
141
142 \pause
143
144 \bigskip
145
146 Several such processes can be combined in which case $Y$ is the
147 concatenation of the separate results. This is \textbf{multi-head
148   attention}.
149
150 \end{frame}
151
152
153 \end{document}