Update.
[tex.git] / attention.tex
1 % -*- mode: latex; mode: reftex; mode: auto-fill; mode: flyspell; -*-
2
3 % Any copyright is dedicated to the Public Domain.
4 % https://creativecommons.org/publicdomain/zero/1.0/
5
6 % Written by Francois Fleuret <francois@fleuret.org>
7
8 \documentclass[c,8pt]{beamer}
9
10 \usepackage{tikz}
11 \newcommand{\transpose}{^{\top}}
12 \def\softmax{\operatorname{softmax}}
13
14 \setbeamertemplate{navigation symbols}{}
15
16 \begin{document}
17
18 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
19
20 \begin{frame}[fragile]
21
22 Given a query sequence $Q$, a key sequence $K$, and a value sequence
23 $V$, compute an attention matrix $A$ by matching $Q$s to $K$s, and
24 weight $V$ with it to get $Y$.
25
26 \medskip
27
28 \[
29 \uncover<2,4,6->{
30 %  A_{i,j} = \softmax \left( \frac{Q_i \cdot K_j}{\sqrt{d}} \right)
31   A_i = \softmax \left( \frac{Q_i \, K\transpose}{\sqrt{d}} \right)
32 }
33 %
34 \quad \quad \quad
35 %
36 \uncover<3,5->{
37   Y_i = A_i V
38 }
39 \]
40
41 \medskip
42
43 \makebox[\textwidth][c]{
44 \begin{tikzpicture}
45
46   \node[xscale=0.5,yslant=0.5] (V) at (-2, 2.35) {
47     \begin{tikzpicture}
48       \draw[fill=green!20] (0, 0) rectangle (4, 1.4);
49       \uncover<3,5>{\draw[fill=yellow] (0, 0) rectangle (4, 1.4);}
50       \foreach \x in { 0.2, 0.4, ..., 3.8 } \draw (\x, 0) -- ++(0, 1.4);
51       %% \foreach \y in { 0.0, 0.2, ..., 1.4 } \draw (0, \y) -- ++(4, 0);
52     \end{tikzpicture}
53   };
54
55   \node[yscale=0.5,xslant=0.5] (A) at (0.5, 1.6) {
56     \begin{tikzpicture}
57       \draw (0, 0) rectangle ++(3, 4);
58       %% \uncover<4->{\draw[fill=green!20] (0, 0) rectangle ++(0.2, 4);}
59       %% \uncover<6->{\draw[fill=green!20] (0.2, 0) rectangle ++(0.2, 4);}
60     \end{tikzpicture}
61   };
62
63   \uncover<2-3>{
64   \node[xscale=0.5,yslant=0.5] (a1) at (-0.9, 2.1) {
65     \begin{tikzpicture}
66       \draw[draw=none] (0, 0) rectangle (4, 1);
67       \foreach \x/\y in {
68         0.00/0.03, 0.20/0.04, 0.40/0.07, 0.60/0.35, 0.80/0.52,
69         1.00/1.00, 1.20/0.82, 1.40/0.25, 1.60/0.08, 1.80/0.03,
70         2.00/0.15, 2.20/0.24, 2.40/0.70, 2.60/0.05, 2.80/0.03,
71         3.00/0.03, 3.20/0.03, 3.40/0.00, 3.60/0.03, 3.80/0.00
72       }{
73         \uncover<2>{\draw[black,fill=orange] (\x, 0) rectangle ++(0.2, \y);}
74         \uncover<3>{\draw[black,fill=yellow] (\x, 0) rectangle ++(0.2, \y);}
75       };
76     \end{tikzpicture}
77   };
78   }
79
80   \uncover<4-5>{
81   \node[xscale=0.5,yslant=0.5] (a2) at (-0.7, 2.1) {
82     \begin{tikzpicture}
83       \draw[draw=none] (0, 0) rectangle (4, 1);
84       \foreach \x/\y in {
85         0.00/0.03, 0.20/0.04, 0.40/0.07, 0.60/0.03, 0.80/0.03,
86         1.00/0.05, 1.20/0.02, 1.40/0.08, 1.60/0.35, 1.80/0.85,
87         2.00/0.05, 2.20/0.04, 2.40/0.03, 2.60/0.05, 2.80/0.03,
88         3.00/0.03, 3.20/0.03, 3.40/0.00, 3.60/0.03, 3.80/0.00
89       }{
90         \uncover<4>{\draw[black,fill=orange] (\x, 0) rectangle ++(0.2, \y);}
91         \uncover<5>{\draw[black,fill=yellow] (\x, 0) rectangle ++(0.2, \y);}
92       };
93     \end{tikzpicture}
94   };
95   }
96
97   \node (Q) at (-0.5, -0.05) {
98     \begin{tikzpicture}
99       \draw[fill=green!20] (0, 0) rectangle (3, 1.0);
100       \foreach \x in { 0.2, 0.4, ..., 2.8 } \draw (\x, 0) -- ++(0, 1.0);
101       \uncover<2>{\draw[fill=yellow] (0.0, 0) rectangle ++(0.2, 1);}
102       \uncover<4>{\draw[fill=yellow] (0.2, 0) rectangle ++(0.2, 1);}
103       %% \foreach \y in { 0.0, 0.2, ..., 1.0 } \draw (0, \y) -- ++(3, 0);
104     \end{tikzpicture}
105     };
106
107   \node (Y) at (1.5, 3.45) {
108     \begin{tikzpicture}
109       \uncover<3>{\draw[fill=orange] (0.0, 0) rectangle ++(0.2, 1.4);}
110       \uncover<4->{\draw[fill=green!20] (0.0, 0) rectangle ++(0.2, 1.4);}
111       \uncover<6->{\draw[fill=green!20] (0.0, 0) rectangle ++(3, 1.4);}
112       \uncover<5>{\draw[fill=orange] (0.2, 0) rectangle ++(0.2, 1.4);}
113       \draw (0, 0) rectangle (3, 1.4);
114       \foreach \x in { 0.2, 0.4, ..., 2.8 } \draw (\x, 0) -- ++(0, 1.4);
115       %% \foreach \y in { 0.0, 0.2, ..., 1.4 } \draw (0, \y) -- ++(3, 0);
116     \end{tikzpicture}
117     };
118
119   \node[xscale=0.5,yslant=0.5] (K) at (3, 1.1) {
120     \begin{tikzpicture}
121       \draw[fill=green!20] (0, 0) rectangle (4, 1);
122       \uncover<2,4>{\draw[fill=yellow] (0, 0) rectangle (4, 1);}
123       \foreach \x in { 0.2, 0.4, ..., 3.8 } \draw (\x, 0) -- ++(0, 1);
124       %% \foreach \y in { 0.0, 0.2, ..., 1.0 } \draw (0, \y) -- ++(4, 0);
125     \end{tikzpicture}
126   };
127
128   \node[left of=V,xshift=0.5cm,yshift=0.7cm] (Vl) {$V$};
129   \node[left of=Q,xshift=-0.8cm] (Ql) {$Q$};
130   \node (Al) at (A) {$A$};
131   \node[right of=K,xshift=-0.6cm,yshift=-0.6cm] (Kl) {$K$};
132   \node[right of=Y,xshift=0.8cm] (Yl) {$Y$};
133
134   %  \uncover<1>{\draw[<->] (2, 0) -- ++ (0, 1) node[midway,right]{$d$};}
135
136 \end{tikzpicture}
137 }
138
139 \end{frame}
140
141 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
142
143 \begin{frame}[fragile]
144
145 A standard attention layer takes as input two sequences $X$ and $X'$
146 and computes
147 %
148 \begin{align*}
149 K & = W^K X \\
150 V & = W^V X \\
151 Q & = w^Q X' \\
152 Y & = \underbrace{\softmax_{row} \left( \frac{Q K\transpose}{\sqrt{d}} \right)}_{A} V
153 \end{align*}
154
155 When $X = X'$, this is \textbf{self attention}, otherwise \textbf{cross
156   attention.}
157
158 \pause
159
160 \bigskip
161
162 Several such processes can be combined in which case $Y$ is the
163 concatenation of the separate results. This is \textbf{multi-head
164   attention}.
165
166 \end{frame}
167
168
169 \end{document}