Update.
[tex.git] / single-attention.tex
diff --git a/single-attention.tex b/single-attention.tex
new file mode 100644 (file)
index 0000000..01a181c
--- /dev/null
@@ -0,0 +1,116 @@
+% -*- mode: latex; mode: reftex; mode: auto-fill; mode: flyspell; mode: yas/minor; coding: utf-8; tex-command: "pdflatex.sh" -*-
+
+% Any copyright is dedicated to the Public Domain.
+% https://creativecommons.org/publicdomain/zero/1.0/
+
+% Written by Francois Fleuret <francois@fleuret.org>
+
+\documentclass[c,8pt]{beamer}
+
+\setbeamertemplate{navigation symbols}{}
+
+% Notation shortcuts used in the node labels and in the formulas.
+% \newcommand* and \DeclareMathOperator raise an error if the name is already
+% taken, whereas \def would silently overwrite an existing command.
+\newcommand*{\transpose}{^{\top}}
+\DeclareMathOperator{\softmax}{softmax}
+
+% Palette. "blue" and "green" deliberately replace the stock colours of the
+% same name, so mixes such as fill=blue!15 and fill=green!30 in the TikZ
+% styles are derived from these darker tones.
+\definecolor{blue}{rgb}{0,0,0.55}
+\definecolor{green}{rgb}{0,0.5,0}
+\definecolor{bluegray}{rgb}{0.1,0.2,0.7}
+
+% Beamer colours: math in bluegray, structural elements in the blue above.
+\setbeamercolor{math text}{fg=bluegray}
+\setbeamercolor{local structure}{fg=blue}
+
+\usepackage{tikz}
+
+\usetikzlibrary{positioning,fit,backgrounds}
+\usetikzlibrary{arrows.meta,decorations.pathreplacing}
+\usetikzlibrary{calc}
+\usetikzlibrary{shapes,calc,intersections}
+\usetikzlibrary{patterns}
+
+\usetikzlibrary{arrows}
+
+% Extra palette entries; nothing in this file refers to them.
+\definecolor{nn-data}{rgb}{0.90,0.95,1.00}
+\definecolor{nn-param}{rgb}{1.00,0.90,0.50}
+\definecolor{nn-process}{rgb}{0.80,1.00,0.80}
+
+% Arrow tip drawn wherever a path asks for ">" (e.g. the "->" in the styles).
+\tikzset{
+  >={Straight Barb[angle'=80,scale=1.1]}
+}
+
+% Node and connector styles. Within each style the key order matters: later
+% keys (e.g. minimum height, inner xsep) override earlier ones (minimum size,
+% inner sep), so the order is kept as is.
+\tikzset{
+  % --- Boxes: white for values, light blue for parameters, light green for operations.
+  value/.style={
+    font=\scriptsize, rectangle, draw=black!50, fill=white, thick,
+    inner sep=3pt, inner xsep=2pt, minimum size=10pt, minimum height=20pt
+  },
+  parameter/.style={
+    font=\scriptsize, rectangle, draw=black!50, fill=blue!15, thick,
+    inner sep=0pt, inner xsep=2pt, minimum size=10pt, minimum height=20pt
+  },
+  operation/.style={
+    font=\scriptsize, rectangle, draw=black!50, fill=green!30, thick,
+    inner sep=3pt, minimum size=10pt, minimum height=20pt
+  },
+  % --- Generic grey arrow, shortened by 1pt at both ends.
+  flow/.style={->,shorten <= 1pt,shorten >= 1pt, draw=black!50, thick},
+  % --- Grey connectors. A bar marks the start of a path, an arrow head its end;
+  %     presumably "v" = value box and "f" = function/operation box (confirm with author).
+  f2f/.style={draw=black!50, thick},
+  v2f/.style={{Bar[width=1.5mm]}-,shorten <= 0.75pt,draw=black!50, thick},
+  f2v/.style={->,shorten >= 0.75pt,draw=black!50, thick},
+  v2v/.style={{Bar[width=1.5mm]}->,shorten <= 0.75pt,shorten >= 0.5pt,draw=black!50, thick},
+  % --- Same four connectors drawn in full black instead of black!50.
+  df2f/.style={draw=black, thick},
+  dv2f/.style={{Bar[width=1.5mm]}-,shorten <= 0.75pt,draw=black, thick},
+  df2v/.style={->,shorten >= 0.75pt,draw=black, thick},
+  dv2v/.style={{Bar[width=1.5mm]}->,shorten <= 0.75pt,shorten >= 0.5pt,draw=black, thick},
+  % --- Yellow box in \small type, plus the black counterpart of "flow".
+  differential/.style={
+    font=\small, rectangle, draw=black!50, thick,
+    inner sep=3pt, inner xsep=2pt, minimum size=10pt, minimum height=20pt, fill=yellow!80
+  },
+  dflow/.style={->,shorten <= 1pt,shorten >= 1pt, draw=black, thick}
+}
+
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+
+\begin{document}
+
+\begin{frame}
+
+\begin{center}
+
+\begin{tikzpicture}
+
+% Boxes, all positioned relative to K at the origin: Q above and V below K,
+% then the operators and results in a row to the right of K.
+\node[value, minimum height=0.8cm, minimum width=0.7cm] (K) at (0,0) {$K$};
+\node[value, minimum height=1.2cm, minimum width=0.7cm, above=0.5cm of K] (Q) {$Q$};
+\node[value, minimum height=0.8cm, minimum width=1.0cm, below=0.5cm of K] (V) {$V$};
+\node[operation, minimum height=0.4cm, minimum width=0.4cm, right=0.5cm of K] (att) {$\cdot\transpose$};
+\node[operation, minimum height=0.4cm, minimum width=0.4cm, right=0.25cm of att] (sm) {$\softmax$};
+\node[value, minimum height=1.2cm, minimum width=0.8cm, right=0.5cm of sm] (A) {$A$};
+\node[operation, minimum height=0.4cm, minimum width=0.4cm, right=0.5cm of A] (prod) {$\cdot$};
+\node[value, minimum height=1.2cm, minimum width=1.0cm, right=0.5cm of prod] (Y) {$Y$};
+
+% K and Q feed the transposed product, whose output goes through softmax into A.
+% Arrows into A and Y stop 1pt short of the box to leave room for the edge bars.
+\draw[v2f,rounded corners=1mm] (K) -- (att);
+\draw[v2f,rounded corners=1mm] (Q) -| (att);
+\draw[f2f,rounded corners=1mm] (att) -- (sm);
+\draw[f2v,rounded corners=1mm] (sm) -- ([xshift=-1pt]A.west);
+
+% A and V feed the second product, which yields Y.
+\draw[v2f,rounded corners=1mm] (A) -- (prod);
+\draw[v2f,rounded corners=1mm] (V) -| (prod);
+\draw[f2v,rounded corners=1mm] (prod) -- ([xshift=-1pt]Y.west);
+
+% Coloured bars 1pt outside each value box. Presumably equal colours mark
+% dimensions that have to agree between two matrices (confirm with author).
+% Bars along the top edge:
+\foreach \nodename/\barcolor in {Q/yellow,K/yellow,V/orange,Y/orange,A/red}
+  \draw[very thick,color=\barcolor]
+    ([yshift=1pt]\nodename.north west) -- ([yshift=1pt]\nodename.north east);
+
+% Bars along the left edge:
+\foreach \nodename/\barcolor in {V/red,K/red,Q/cyan,Y/cyan,A/cyan}
+  \draw[very thick,color=\barcolor]
+    ([xshift=-1pt]\nodename.north west) -- ([xshift=-1pt]\nodename.south west);
+
+\end{tikzpicture}
+
+% A is the row-wise softmax of the scaled product of Q and the transposed K;
+% Y is A applied to V. The subscript "row" is a word label, not a product of
+% the variables r, o, w, so it is set upright with \text (which follows the
+% surrounding text font).
+\begin{align*}
+A & = \softmax_{\text{row}} \left( \frac{Q K\transpose}{\sqrt{D}} \right) \\
+Y & = A V.
+\end{align*}
+
+Single-head attention
+
+\end{center}
+
+\end{frame}
+
+\end{document}