Initial commit
author    François Fleuret <francois@fleuret.org>
          Fri, 16 Jun 2023 14:45:41 +0000 (16:45 +0200)
committer François Fleuret <francois@fleuret.org>
          Fri, 16 Jun 2023 14:45:41 +0000 (16:45 +0200)
warp.py [new file with mode: 0755]
warp.tex [new file with mode: 0644]

diff --git a/warp.py b/warp.py
new file mode 100755 (executable)
index 0000000..96dfa11
--- /dev/null
+++ b/warp.py
@@ -0,0 +1,168 @@
+#!/usr/bin/env python
+
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math, argparse, os
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument("--result_dir", type=str, default="/tmp")
+
+args = parser.parse_args()
+
+######################################################################
+
+# If the source is older than the result, do nothing
+
+ref_filename = os.path.join(args.result_dir, "warp_0.tex")
+
+if os.path.exists(ref_filename) and os.path.getmtime(__file__) < os.path.getmtime(
+    ref_filename
+):
+    exit(0)
+
+######################################################################
+
+torch.manual_seed(0)
+
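+# Two-class "interleaved crescents" training set: nb points drawn on a
+# thin annular arc, with the first coordinate mirrored according to the
+# class and both coordinates shifted so that the two classes interlock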
+nb = 1000
+x = torch.rand(nb, 2) * torch.tensor([math.pi * 1.5, 0.10]) + torch.tensor(
+    [math.pi * -0.25, 0.25]
+)
+
+train_targets = (torch.rand(nb) < 0.5).long()
+train_input = torch.cat((x[:, 0:1].sin() * x[:, 1:2], x[:, 0:1].cos() * x[:, 1:2]), 1)
+train_input[:, 0] *= train_targets * 2 - 1
+train_input[:, 0] += 0.05 * (train_targets * 2 - 1)
+train_input[:, 1] -= 0.15 * (train_targets * 2 - 1)
+train_input *= 1.2
+
+
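+# Residual wrapper computing 0.5 * x + 0.5 * f(x); note that it is not
+# used in the model below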
+class WithResidual(nn.Module):
+    def __init__(self, *f):
+        super().__init__()
+        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+    def forward(self, x):
+        return 0.5 * x + 0.5 * self.f(x)
+
+
+model = nn.Sequential(
+    nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()),
+    nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()),
+    nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()),
+    nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()),
+    nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()),
+    nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()),
+    nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()),
+    nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()),
+    nn.Linear(2, 2),
+)
+
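+# Initialize every linear layer close to twice the identity, with a
+# tiny random perturbation to break the symmetry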
+with torch.no_grad():
+    for p in model.modules():
+        if isinstance(p, nn.Linear):
+            # p.bias.zero_()
+            p.weight[...] = 2 * torch.eye(2) + torch.randn(2, 2) * 1e-4
+
+optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+criterion = nn.CrossEntropyLoss()
+
+nb_epochs, batch_size = 1000, 25
+
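+# Standard minibatch training, stopping as soon as the training error
+# reaches zero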
+for k in range(nb_epochs):
+    acc_loss = 0.0
+
+    for input, targets in zip(
+        train_input.split(batch_size), train_targets.split(batch_size)
+    ):
+        output = model(input)
+        loss = criterion(output, targets)
+        acc_loss += loss.item()
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+    nb_train_errors = 0
+    for input, targets in zip(
+        train_input.split(batch_size), train_targets.split(batch_size)
+    ):
+        wta = model(input).argmax(1)
+        nb_train_errors += (wta != targets).long().sum()
+    train_error = nb_train_errors / train_input.size(0)
+
+    print(f"loss {k} {acc_loss:.02f} {train_error*100:.02f}%")
+
+    if train_error == 0:
+        break
+
+######################################################################
+
+# Regular square grid of sg x sg points covering [-1.2, 1.2]^2, warped
+# through the layers together with the training points
+
+sg = 25
+
+input, targets = train_input, train_targets
+
+grid = torch.linspace(-1.2, 1.2, sg)
+grid = torch.cat(
+    (grid[:, None, None].expand(sg, sg, 1), grid[None, :, None].expand(sg, sg, 1)), -1
+).reshape(-1, 2)
+
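+# For each of the nine top-level modules, dump the current training
+# points and grid to warp_<l>.tex, then push both through the module;
+# hence warp_0.tex shows the input and warp_8.tex the representation in
+# which the final affine layer takes its decision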
+for l, m in enumerate(model):
+    with open(os.path.join(args.result_dir, f"warp_{l}.tex"), "w") as f:
+        f.write(
+            """\\addplot[
+    scatter src=explicit symbolic,
+    scatter/classes={0={blue}, 1={red}},
+    scatter, mark=*, only marks, mark options={mark size=0.5},
+]%
+table[meta=label] {
+x y label
+"""
+        )
+        # Only the first 512 of the nb training points are written out
+        for k in range(512):
+            f.write(f"{input[k,0]} {input[k,1]} {targets[k]}\n")
+        f.write("};\n")
+
+        g = grid.reshape(sg, sg, -1)
+
+        # Draw the grid, one polyline per grid row, then one per column
+
+        for i in range(g.size(0)):
+            for j in range(g.size(1)):
+                if j == 0:
+                    pre = "\\draw[black!25,very thin] "
+                else:
+                    pre = "--"
+                f.write(f"{pre} ({g[i,j,0]},{g[i,j,1]})")
+            f.write(";\n")
+
+        for j in range(g.size(1)):
+            for i in range(g.size(0)):
+                if i == 0:
+                    pre = "\\draw[black!25,very thin] "
+                else:
+                    pre = "--"
+                f.write(f"{pre} ({g[i,j,0]},{g[i,j,1]})")
+            f.write(";\n")
+
+        # Add the decision boundary of the final affine layer: with
+        # logits o = W x + c, the boundary is the line a . x + b = 0
+        # with a = W[0] - W[1] and b = c[0] - c[1], and p is its point
+        # closest to the origin
+
+        if l == len(model) - 1:
+            u = torch.tensor([[1.0, -1.0]])
+            phi = model[-1]
+            a, b = (u @ phi.weight).squeeze(), (u @ phi.bias).item()
+            p = -a * (b / (a @ a).item())
+            f.write(
+                f"\\draw[black,thick] ({p[0]-a[1]},{p[1]+a[0]}) -- ({p[0]+a[1]},{p[1]-a[0]});"
+            )
+
+    input, grid = m(input), m(grid)
+
+######################################################################
diff --git a/warp.tex b/warp.tex
new file mode 100644 (file)
index 0000000..d03b295
--- /dev/null
+++ b/warp.tex
@@ -0,0 +1,66 @@
+%% -*- mode: latex; mode: reftex; mode: flyspell; coding: utf-8; tex-command: "pdflatex.sh" -*-
+
+\documentclass[11pt,a4paper,twoside]{article}
+\usepackage[a4paper,top=2.5cm,bottom=2cm,left=2.5cm,right=2.5cm]{geometry}
+\usepackage[colorlinks=true,linkcolor=blue,urlcolor=blue,citecolor=blue]{hyperref}
+\usepackage{amsmath}
+\usepackage{amssymb}
+\usepackage{dsfont}
+\usepackage{tikz}
+\usetikzlibrary{arrows,arrows.meta,calc}
+\usetikzlibrary{patterns,backgrounds}
+\usetikzlibrary{positioning,fit}
+\usetikzlibrary{shapes.geometric,shapes.multipart}
+\usetikzlibrary{patterns.meta,decorations.pathreplacing,calligraphy}
+\usetikzlibrary{tikzmark}
+\usetikzlibrary{decorations.pathmorphing}
+\usepackage{pgfplots}
+\usepgfplotslibrary{patchplots,colormaps}
+\pgfplotsset{compat = newest}
+
+
+\begin{document}
+
+\definecolor{blue}{rgb}{0.3,0.5,0.85}
+\definecolor{red}{rgb}{0.65,0.0,0.0}
+
+\begin{figure}
+
+  \immediate\write18{./warp.py --result_dir=.}
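+  % Note: \write18 requires shell escape to be enabled, e.g. compiling
+  % with pdflatex -shell-escape warp.tex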
+
+  \newcommand{\warp}[1]{%
+    \begin{tikzpicture}
+      \begin{axis}[ticks=none,width=7.0cm, height=7.0cm,xmin=-1.2,xmax=1.2,ymin=-1.2,ymax=1.2]
+        \input{#1}
+      \end{axis}
+    \end{tikzpicture}
+  }
+
+  \centering
+
+  \begin{tikzpicture}[warp/.style={inner sep=1pt,minimum width=5.0cm,minimum height=5.0cm}]
+    \node[warp]                 (W0) {\warp{warp_0.tex}};
+    \node[warp,right=2pt of W0] (W1) {\warp{warp_1.tex}};
+    \node[warp,right=2pt of W1] (W2) {\warp{warp_2.tex}};
+    \node[warp,below=20pt of W0] (W3) {\warp{warp_3.tex}};
+    \node[warp,right=2pt of W3] (W4) {\warp{warp_4.tex}};
+    \node[warp,right=2pt of W4] (W5) {\warp{warp_5.tex}};
+    \node[warp,below=20pt of W3] (W6) {\warp{warp_6.tex}};
+    \node[warp,right=2pt of W6] (W7) {\warp{warp_7.tex}};
+    \node[warp,right=2pt of W7] (W8) {\warp{warp_8.tex}};
+    \node[inner sep=0pt,below=4pt of W0] (lW0) {\footnotesize Input};
+    \foreach \n in {1,...,8}{
+      \node[inner sep=0pt,below=4pt of W\n] (lW\n) {\footnotesize Layer \#\n};
+    };
+
+  \end{tikzpicture}
+
+  \caption[Feature warping]{Each plot shows the deformation of the
+    space and the resulting distribution of the training points in
+    $\mathbb{R}^2$ at the output of successive layers, starting with
+    the input in the top-left square. The thick oblique line in the
+    bottom-right plot is the decision boundary of the final affine
+    layer.}\label{fig:warp}
+
+\end{figure}
+
+\end{document}