Update.
[tex.git] / single-attention.tex
index 01a181c..d21595e 100644 (file)
@@ -113,4 +113,70 @@ Single-head attention
 
 \end{frame}
 
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+
+\begin{frame} % slide holding a single centered TikZ diagram: X, W^Q -> Q; Q, K -> A; A, V -> Y
+
+\begin{center} % center the picture horizontally on the slide
+
+\begin{tikzpicture} % styles value/parameter/operation/v2f/f2v/f2f and \transpose, \softmax are not defined in this hunk
+
+\node[value,    minimum height=0.8cm,minimum width=0.7cm] (K) at (0, 0) {$K$}; % K box at the origin; every other node is placed relative to it
+\draw[very thick,yellow] ([yshift=1pt]K.north west) -- ([yshift=1pt]K.north east); % yellow bar 1pt above K's top edge
+\draw[very thick,red] ([xshift=-1pt]K.north west) -- ([xshift=-1pt]K.south west); % red bar 1pt left of K's left edge
+
+\node[value,    minimum height=1.2cm,minimum width=0.7cm] (Q) [above=1cm of K] {$Q$}; % Q box, 1cm above K, same width as K
+\draw[very thick,yellow] ([yshift=1pt]Q.north west) -- ([yshift=1pt]Q.north east); % yellow top bar, same colour as on K and W^Q
+\draw[very thick,cyan] ([xshift=-1pt]Q.north west) -- ([xshift=-1pt]Q.south west); % cyan left bar, same colour as on X, A and Y
+
+\node[value,    minimum height=0.8cm,minimum width=1.0cm] (V) [below=1cm of K] {$V$}; % V box, 1cm below K, same height as K but wider
+\draw[very thick,orange] ([yshift=1pt]V.north west) -- ([yshift=1pt]V.north east); % orange top bar, same colour as on Y
+\draw[very thick,red] ([xshift=-1pt]V.north west) -- ([xshift=-1pt]V.south west); % red left bar, same colour as on K
+
+\node[operation,minimum height=0.4cm] (mulWq) [left=1cm of Q.center] {$\cdot$}; % product node to the left of Q (measured from Q's center)
+\node[operation,minimum height=0.4cm] (mulWk) [left=1cm of K.center] {$\cdot$}; % NOTE(review): mulWk is placed but no edge in this hunk touches it
+\node[operation,minimum height=0.4cm] (mulWv) [left=1cm of V.center] {$\cdot$}; % NOTE(review): mulWv is placed but no edge in this hunk touches it
+
+\node[value,    minimum height=1.2cm,minimum width=0.5cm] (X) [left=1cm of mulWq] {$X$}; % X box left of mulWq, same height as Q
+\draw[very thick,cyan] ([xshift=-1pt]X.north west) -- ([xshift=-1pt]X.south west); % cyan left bar
+\draw[very thick,green] ([yshift=1pt]X.north west) -- ([yshift=1pt]X.north east); % green top bar, same colour as W^Q's left bar
+
+\node[parameter,    minimum height=0.5cm,minimum width=0.7cm] (Wq) [above=0.25 cm of X] {$W^Q$}; % W^Q box just above X; same width as Q
+\draw[very thick,green] ([xshift=-1pt]Wq.north west) -- ([xshift=-1pt]Wq.south west); % green left bar
+\draw[very thick,yellow] ([yshift=1pt]Wq.north west) -- ([yshift=1pt]Wq.north east); % yellow top bar
+
+\node[value,    minimum height=0.8cm,minimum width=0.3cm] (X') [below=1.2cm of X]  {$X'$}; % NOTE(review): X' has no coloured bars and no edges in this hunk
+
+\node[operation,minimum height=0.4cm,minimum width=0.4cm] (att) [right=0.5cm of K] {$\cdot\transpose$}; % operation node right of K, fed by Q and K below
+\node[operation,minimum height=0.4cm,minimum width=0.4cm] (sm) [right=0.25cm of att] {$\softmax$}; % softmax node immediately right of att
+
+\node[value,    minimum height=1.2cm,minimum width=0.8cm] (A) [right=0.5cm of sm] {$A$}; % A box right of the softmax node, same height as Q
+\draw[very thick,cyan] ([xshift=-1pt]A.north west) -- ([xshift=-1pt]A.south west); % cyan left bar
+\draw[very thick,red] ([yshift=1pt]A.north west) -- ([yshift=1pt]A.north east); % red top bar, same colour as the left bars of K and V
+
+\node[operation,minimum height=0.4cm,minimum width=0.4cm] (prod) [right=0.5cm of A] {$\cdot$}; % product node right of A, fed by A and V below
+
+\node[value,    minimum height=1.2cm,minimum width=1.0cm] (Y) [right=0.5cm of prod] {$Y$}; % Y box: same height as A, same width as V
+\draw[very thick,orange] ([yshift=1pt]Y.north west) -- ([yshift=1pt]Y.north east); % orange top bar
+\draw[very thick,cyan] ([xshift=-1pt]Y.north west) -- ([xshift=-1pt]Y.south west); % cyan left bar
+
+\draw[v2f,rounded corners=1mm] (X) -- (mulWq); % straight edge X -> mulWq
+\draw[v2f,rounded corners=1mm] (Wq) -| (mulWq); % W^Q -> mulWq, routed horizontally first and then vertically
+\draw[f2v,rounded corners=1mm] (mulWq) -- ([xshift=-1pt]Q.west); % mulWq -> Q, ending 1pt left of Q.west, where the left bar is drawn
+
+\draw[v2f,rounded corners=1mm] (K) -- (att); % straight edge K -> att
+\draw[v2f,rounded corners=1mm] (Q) -| (att); % Q -> att, routed horizontally first and then vertically
+\draw[f2f,rounded corners=1mm] (att) -- (sm); % att -> softmax (the only f2f edge in this picture)
+\draw[f2v,rounded corners=1mm] (sm) -- ([xshift=-1pt]A.west); % softmax -> A, ending 1pt left of A.west
+
+\draw[v2f,rounded corners=1mm] (A) -- (prod); % straight edge A -> prod
+\draw[v2f,rounded corners=1mm] (V) -| (prod); % V -> prod, routed horizontally first and then vertically
+\draw[f2v,rounded corners=1mm] (prod) -- ([xshift=-1pt]Y.west); % prod -> Y, ending 1pt left of Y.west
+
+\end{tikzpicture}
+
+\end{center}
+
+\end{frame}
+
 \end{document}