From 6e541e7102264b99f1a4aa72325a2b4b81fcb3eb Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Thu, 26 May 2022 17:24:55 +0200 Subject: [PATCH] Update. --- attention.tex | 49 +++++++++++------- single-attention.tex | 116 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 147 insertions(+), 18 deletions(-) create mode 100644 single-attention.tex diff --git a/attention.tex b/attention.tex index b6f15dd..276759a 100644 --- a/attention.tex +++ b/attention.tex @@ -1,5 +1,7 @@ % -*- mode: latex; mode: reftex; mode: auto-fill; mode: flyspell; -*- +% Written by Francois Fleuret + \documentclass[c,8pt]{beamer} \usepackage{tikz} @@ -22,6 +24,7 @@ weight $V$ with it to get $Y$. \[ \uncover<2,4,6->{ +% A_{i,j} = \softmax \left( \frac{Q_i \cdot K_j}{\sqrt{d}} \right) A_i = \softmax \left( \frac{Q_i \, K\transpose}{\sqrt{d}} \right) } % @@ -37,30 +40,34 @@ weight $V$ with it to get $Y$. \makebox[\textwidth][c]{ \begin{tikzpicture} - \node[cm={0.5, 0.5, 0.0, 1.0, (0.0, 0.0)}] (V) at (-2, 2.35) { + \node[xscale=0.5,yslant=0.5] (V) at (-2, 2.35) { \begin{tikzpicture} \draw[fill=green!20] (0, 0) rectangle (4, 1.4); \uncover<3,5>{\draw[fill=yellow] (0, 0) rectangle (4, 1.4);} \foreach \x in { 0.2, 0.4, ..., 3.8 } \draw (\x, 0) -- ++(0, 1.4); + %% \foreach \y in { 0.0, 0.2, ..., 1.4 } \draw (0, \y) -- ++(4, 0); \end{tikzpicture} }; - \node[cm={1.0, 0.0, 0.5, 0.5, (0.0, 0.0)}] (A) at (0.5, 1.6) { + \node[yscale=0.5,xslant=0.5] (A) at (0.5, 1.6) { \begin{tikzpicture} \draw (0, 0) rectangle ++(3, 4); + %% \uncover<4->{\draw[fill=green!20] (0, 0) rectangle ++(0.2, 4);} + %% \uncover<6->{\draw[fill=green!20] (0.2, 0) rectangle ++(0.2, 4);} \end{tikzpicture} }; \uncover<2-3>{ - \node[cm={0.5, 0.5, 0.0, 1.0, (0.0, 0.0)}] (a1) at (-0.9, 2.1) { + \node[xscale=0.5,yslant=0.5] (a1) at (-0.9, 2.1) { \begin{tikzpicture} \draw[draw=none] (0, 0) rectangle (4, 1); \foreach \x/\y in { 0.00/0.03, 0.20/0.04, 0.40/0.07, 0.60/0.35, 0.80/0.52, 1.00/1.00, 1.20/0.82, 1.40/0.25, 1.60/0.08, 1.80/0.03, 2.00/0.15, 2.20/0.24, 2.40/0.70, 2.60/0.05, 2.80/0.03, - 3.00/0.03, 3.20/0.03, 3.40/0.00, 3.60/0.03, 3.80/0.00 }{ - \uncover<2>{\draw[black,fill=red] (\x, 0) rectangle ++(0.2, \y);} + 3.00/0.03, 3.20/0.03, 3.40/0.00, 3.60/0.03, 3.80/0.00 + }{ + \uncover<2>{\draw[black,fill=orange] (\x, 0) rectangle ++(0.2, \y);} \uncover<3>{\draw[black,fill=yellow] (\x, 0) rectangle ++(0.2, \y);} }; \end{tikzpicture} @@ -68,54 +75,60 @@ weight $V$ with it to get $Y$. } \uncover<4-5>{ - \node[cm={0.5, 0.5, 0.0, 1.0, (0.0, 0.0)}] (a2) at (-0.7, 2.1) { + \node[xscale=0.5,yslant=0.5] (a2) at (-0.7, 2.1) { \begin{tikzpicture} \draw[draw=none] (0, 0) rectangle (4, 1); \foreach \x/\y in { 0.00/0.03, 0.20/0.04, 0.40/0.07, 0.60/0.03, 0.80/0.03, 1.00/0.05, 1.20/0.02, 1.40/0.08, 1.60/0.35, 1.80/0.85, 2.00/0.05, 2.20/0.04, 2.40/0.03, 2.60/0.05, 2.80/0.03, - 3.00/0.03, 3.20/0.03, 3.40/0.00, 3.60/0.03, 3.80/0.00 }{ - \uncover<4>{\draw[black,fill=red] (\x, 0) rectangle ++(0.2, \y);} + 3.00/0.03, 3.20/0.03, 3.40/0.00, 3.60/0.03, 3.80/0.00 + }{ + \uncover<4>{\draw[black,fill=orange] (\x, 0) rectangle ++(0.2, \y);} \uncover<5>{\draw[black,fill=yellow] (\x, 0) rectangle ++(0.2, \y);} }; \end{tikzpicture} }; } - \node[cm={1.0, 0.0, 0.0, 1.0, (0.0, 0.0)}] (Q) at (-0.5, -0.05) { + \node (Q) at (-0.5, -0.05) { \begin{tikzpicture} \draw[fill=green!20] (0, 0) rectangle (3, 1.0); \foreach \x in { 0.2, 0.4, ..., 2.8 } \draw (\x, 0) -- ++(0, 1.0); \uncover<2>{\draw[fill=yellow] (0.0, 0) rectangle ++(0.2, 1);} \uncover<4>{\draw[fill=yellow] (0.2, 0) rectangle ++(0.2, 1);} + %% \foreach \y in { 0.0, 0.2, ..., 1.0 } \draw (0, \y) -- ++(3, 0); \end{tikzpicture} }; - \node[cm={1.0, 0.0, 0.0, 1.0, (0.0, 0.0)}] (Y) at (1.5, 3.45) { + \node (Y) at (1.5, 3.45) { \begin{tikzpicture} - \uncover<3>{\draw[fill=red] (0.0, 0) rectangle ++(0.2, 1.4);} + \uncover<3>{\draw[fill=orange] (0.0, 0) rectangle ++(0.2, 1.4);} \uncover<4->{\draw[fill=green!20] (0.0, 0) rectangle ++(0.2, 1.4);} \uncover<6->{\draw[fill=green!20] (0.0, 0) rectangle ++(3, 1.4);} - \uncover<5>{\draw[fill=red] (0.2, 0) rectangle ++(0.2, 1.4);} + \uncover<5>{\draw[fill=orange] (0.2, 0) rectangle ++(0.2, 1.4);} \draw (0, 0) rectangle (3, 1.4); \foreach \x in { 0.2, 0.4, ..., 2.8 } \draw (\x, 0) -- ++(0, 1.4); + %% \foreach \y in { 0.0, 0.2, ..., 1.4 } \draw (0, \y) -- ++(3, 0); \end{tikzpicture} }; - \node[cm={0.5, 0.5, 0.0, 1.0, (0.0, 0.0)}] (K) at (3, 1.1) { + \node[xscale=0.5,yslant=0.5] (K) at (3, 1.1) { \begin{tikzpicture} \draw[fill=green!20] (0, 0) rectangle (4, 1); \uncover<2,4>{\draw[fill=yellow] (0, 0) rectangle (4, 1);} \foreach \x in { 0.2, 0.4, ..., 3.8 } \draw (\x, 0) -- ++(0, 1); + %% \foreach \y in { 0.0, 0.2, ..., 1.0 } \draw (0, \y) -- ++(4, 0); \end{tikzpicture} }; - \node[left of=V,xshift=0.5cm,yshift=0.7cm] (Vl) {V}; - \node[left of=Q,xshift=-0.8cm] (Ql) {Q}; - \node (Al) at (A) {A}; - \node[right of=K,xshift=-0.6cm,yshift=-0.6cm] (Kl) {K}; - \node[right of=Y,xshift=0.8cm] (Yl) {Y}; + \node[left of=V,xshift=0.5cm,yshift=0.7cm] (Vl) {$V$}; + \node[left of=Q,xshift=-0.8cm] (Ql) {$Q$}; + \node (Al) at (A) {$A$}; + \node[right of=K,xshift=-0.6cm,yshift=-0.6cm] (Kl) {$K$}; + \node[right of=Y,xshift=0.8cm] (Yl) {$Y$}; + + % \uncover<1>{\draw[<->] (2, 0) -- ++ (0, 1) node[midway,right]{$d$};} \end{tikzpicture} } diff --git a/single-attention.tex b/single-attention.tex new file mode 100644 index 0000000..01a181c --- /dev/null +++ b/single-attention.tex @@ -0,0 +1,116 @@ +% -*- mode: latex; mode: reftex; mode: auto-fill; mode: flyspell; mode: yas/minor; coding: utf-8; tex-command: "pdflatex.sh" -*- + +% Any copyright is dedicated to the Public Domain. +% https://creativecommons.org/publicdomain/zero/1.0/ + +% Written by Francois Fleuret + +\documentclass[c,8pt]{beamer} + +\setbeamertemplate{navigation symbols}{} + +\def\transpose{^{\top}} +\def\softmax{\operatorname{softmax}} + +\definecolor{blue}{rgb}{0.0,0.0,0.55} +\definecolor{green}{rgb}{0.0,0.50,0.0} +\definecolor{bluegray}{rgb}{0.1,0.2,0.7} + +\setbeamercolor{math text}{fg=bluegray} +\setbeamercolor{local structure}{fg=blue} + +\usepackage{tikz} + +\usetikzlibrary{positioning,fit,backgrounds} +\usetikzlibrary{arrows.meta,decorations.pathreplacing} +\usetikzlibrary{calc} +\usetikzlibrary{shapes,calc,intersections} +\usetikzlibrary{patterns} + +\usetikzlibrary{arrows} + +\definecolor{nn-data} {rgb}{0.90, 0.95, 1.00} +\definecolor{nn-param} {rgb}{1.00, 0.90, 0.50} +\definecolor{nn-process}{rgb}{0.80, 1.00, 0.80} +\tikzset{>={Straight Barb[angle'=80,scale=1.1]}} + +\tikzset{ + value/.style ={ font=\scriptsize, rectangle, draw=black!50, fill=white, thick, + inner sep=3pt, inner xsep=2pt, minimum size=10pt, minimum height=20pt }, + parameter/.style={ font=\scriptsize, rectangle, draw=black!50, fill=blue!15, thick, + inner sep=0pt, inner xsep=2pt, minimum size=10pt, minimum height=20pt }, + operation/.style={ font=\scriptsize, rectangle, draw=black!50, fill=green!30, thick, + inner sep=3pt, minimum size=10pt, minimum height=20pt }, + flow/.style={->,shorten <= 1pt,shorten >= 1pt, draw=black!50, thick}, +% + f2f/.style={draw=black!50, thick}, + v2f/.style={{Bar[width=1.5mm]}-,shorten <= 0.75pt,draw=black!50, thick}, + f2v/.style={->,shorten >= 0.75pt,draw=black!50, thick}, + v2v/.style={{Bar[width=1.5mm]}->,shorten <= 0.75pt,shorten >= 0.5pt,draw=black!50, thick}, +% +% + df2f/.style={draw=black, thick}, + dv2f/.style={{Bar[width=1.5mm]}-,shorten <= 0.75pt,draw=black, thick}, + df2v/.style={->,shorten >= 0.75pt,draw=black, thick}, + dv2v/.style={{Bar[width=1.5mm]}->,shorten <= 0.75pt,shorten >= 0.5pt,draw=black, thick}, +% + differential/.style ={ font=\small, rectangle, draw=black!50, thick, + inner sep=3pt, inner xsep=2pt, minimum size=10pt, minimum height=20pt, fill=yellow!80 }, + dflow/.style={->,shorten <= 1pt,shorten >= 1pt, draw=black, thick} +} + +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +\begin{document} + +\begin{frame} + +\begin{center} + +\begin{tikzpicture} + +\node[value, minimum height=0.8cm,minimum width=0.7cm] (K) at (0, 0) {$K$}; +\node[value, minimum height=1.2cm,minimum width=0.7cm] (Q) [above=0.5cm of K] {$Q$}; +\node[value, minimum height=0.8cm,minimum width=1.0cm] (V) [below=0.5cm of K] {$V$}; +\node[operation,minimum height=0.4cm,minimum width=0.4cm] (att) [right=0.5cm of K] {$\cdot\transpose$}; +\node[operation,minimum height=0.4cm,minimum width=0.4cm] (sm) [right=0.25cm of att] {$\softmax$}; +\node[value, minimum height=1.2cm,minimum width=0.8cm] (A) [right=0.5cm of sm] {$A$}; +\node[operation,minimum height=0.4cm,minimum width=0.4cm] (prod) [right=0.5cm of A] {$\cdot$}; +\node[value, minimum height=1.2cm,minimum width=1.0cm] (Y) [right=0.5cm of prod] {$Y$}; + +\draw[v2f,rounded corners=1mm] (K) -- (att); +\draw[v2f,rounded corners=1mm] (Q) -| (att); +\draw[f2f,rounded corners=1mm] (att) -- (sm); +\draw[f2v,rounded corners=1mm] (sm) -- ([xshift=-1pt]A.west); + +\draw[v2f,rounded corners=1mm] (A) -- (prod); +\draw[v2f,rounded corners=1mm] (V) -| (prod); +\draw[f2v,rounded corners=1mm] (prod) -- ([xshift=-1pt]Y.west); + +\draw[very thick,yellow] ([yshift=1pt]Q.north west) -- ([yshift=1pt]Q.north east); +\draw[very thick,yellow] ([yshift=1pt]K.north west) -- ([yshift=1pt]K.north east); +\draw[very thick,orange] ([yshift=1pt]V.north west) -- ([yshift=1pt]V.north east); +\draw[very thick,orange] ([yshift=1pt]Y.north west) -- ([yshift=1pt]Y.north east); + +\draw[very thick,red] ([xshift=-1pt]V.north west) -- ([xshift=-1pt]V.south west); +\draw[very thick,red] ([xshift=-1pt]K.north west) -- ([xshift=-1pt]K.south west); +\draw[very thick,cyan] ([xshift=-1pt]Q.north west) -- ([xshift=-1pt]Q.south west); +\draw[very thick,cyan] ([xshift=-1pt]Y.north west) -- ([xshift=-1pt]Y.south west); + +\draw[very thick,cyan] ([xshift=-1pt]A.north west) -- ([xshift=-1pt]A.south west); +\draw[very thick,red] ([yshift=1pt]A.north west) -- ([yshift=1pt]A.north east); + +\end{tikzpicture} + +\begin{align*} +A & = \softmax_{row} \left( \frac{Q K\transpose}{\sqrt{D}} \right) \\ +Y & = A V. +\end{align*} + +Single-head attention + +\end{center} + +\end{frame} + +\end{document} -- 2.20.1