Update.
[tex.git] / warp.py
1 #!/usr/bin/env python
2
3
4 # Any copyright is dedicated to the Public Domain.
5 # https://creativecommons.org/publicdomain/zero/1.0/
6
7 # Written by Francois Fleuret <francois@fleuret.org>
8
9 import math, argparse, os
10
11 import torch, torchvision
12
13 from torch import nn
14 from torch.nn import functional as F
15
16 ######################################################################
17
# Command-line interface: the single option is the directory that
# receives the generated warp_<layer>.tex files.
parser = argparse.ArgumentParser()

parser.add_argument("--result_dir", type=str, default="/tmp")

args = parser.parse_args()

######################################################################

# If the source is older than the result, do nothing

ref_filename = os.path.join(args.result_dir, "warp_0.tex")

if os.path.exists(ref_filename) and os.path.getmtime(__file__) < os.path.getmtime(
    ref_filename
):
    # Raise SystemExit directly rather than calling the bare exit()
    # helper: that helper is injected by the site module for interactive
    # sessions and is not available under "python -S" or in frozen
    # interpreters.
    raise SystemExit(0)
34
35 ######################################################################
36
# Fixed seed so that every run produces the same samples.
torch.manual_seed(0)

nb = 1000

# Column 0 of x is an angle drawn uniformly in [-0.25 pi, 1.25 pi),
# column 1 a radius drawn uniformly in [0.25, 0.35).
scale = torch.tensor([math.pi * 1.5, 0.10])
offset = torch.tensor([math.pi * -0.25, 0.25])
x = torch.rand(nb, 2) * scale + offset

# Binary labels, and the matching -1 / +1 sign used to place each class.
train_targets = (torch.rand(nb) < 0.5).long()
sign = train_targets * 2 - 1

# Polar to Cartesian: (sin(angle) * radius, cos(angle) * radius).
angle, radius = x[:, 0:1], x[:, 1:2]
train_input = torch.cat((angle.sin() * radius, angle.cos() * radius), 1)

# Mirror the first coordinate for class 0, then shift the two classes in
# opposite directions along both axes before the global rescaling.
train_input[:, 0] *= sign
train_input[:, 0] += 0.05 * sign
train_input[:, 1] -= 0.15 * sign
train_input *= 1.2
50
51
class WithResidual(nn.Module):
    """Average the input with the output of the wrapped module(s).

    The forward pass returns 0.5 * x + 0.5 * f(x). A single module
    passed to the constructor is used as is; any other number of
    modules is chained with nn.Sequential.
    """

    def __init__(self, *f):
        super().__init__()
        if len(f) == 1:
            self.f = f[0]
        else:
            self.f = nn.Sequential(*f)

    def forward(self, x):
        # The skip term is computed before calling the wrapped module,
        # in the same order as the one-line expression it replaces.
        skip = 0.5 * x
        branch = 0.5 * self.f(x)
        return skip + branch
59
60
# Eight identical stages (bias-free 2x2 linear map followed by tanh) and
# a final 2 -> 2 linear layer with bias. The stages are instantiated in
# order, then the final layer, so the random initialisation draws happen
# in the same sequence as when each layer is written out by hand.
hidden_blocks = [
    nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()) for _ in range(8)
]
model = nn.Sequential(*hidden_blocks, nn.Linear(2, 2))

# Start every linear map close to 2 * identity, with a small random
# perturbation. Only the weights are overwritten: the bias of the final
# layer keeps its default initialisation.
with torch.no_grad():
    linear_layers = (m for m in model.modules() if isinstance(m, nn.Linear))
    for layer in linear_layers:
        layer.weight.copy_(2 * torch.eye(2) + torch.randn(2, 2) * 1e-4)
78
# Train with Adam on the cross-entropy loss, in mini-batches taken in
# the order of the training set, until no training sample is
# misclassified or nb_epochs is reached.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

nb_epochs, batch_size = 1000, 25

for k in range(nb_epochs):
    # Sum of the per-batch losses over the epoch, for logging only.
    acc_loss = 0.0

    for input, targets in zip(
        train_input.split(batch_size), train_targets.split(batch_size)
    ):
        output = model(input)
        loss = criterion(output, targets)
        acc_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Count the training errors of the updated model. Only forward
    # values are needed here, so autograd is disabled to avoid building
    # graphs that would never be back-propagated through.
    nb_train_errors = 0
    with torch.no_grad():
        for input, targets in zip(
            train_input.split(batch_size), train_targets.split(batch_size)
        ):
            wta = model(input).argmax(1)
            nb_train_errors += (wta != targets).sum().item()
    train_error = nb_train_errors / train_input.size(0)

    print(f"loss {k} {acc_loss:.02f} {train_error*100:.02f}%")

    # Stop as soon as the training set is perfectly classified.
    if train_error == 0:
        break
110
111 ######################################################################
112
# For every layer of the model, write warp_<index>.tex with the
# representation that enters that layer: a scatter plot of the first
# training samples coloured by class, and a regular grid of the input
# plane deformed by the preceding layers. The file of the last layer
# also gets the decision line of the final linear layer.

# Number of grid nodes along each axis.
sg = 25

input, targets = train_input, train_targets

# (sg * sg, 2) tensor listing the nodes of a regular grid over
# [-1.2, 1.2]^2, the first coordinate varying slowest.
grid = torch.linspace(-1.2, 1.2, sg)
grid = torch.cat(
    (grid[:, None, None].expand(sg, sg, 1), grid[None, :, None].expand(sg, sg, 1)),
    -1,
).reshape(-1, 2)


def _polyline(points):
    """Return one TikZ draw command joining the given (x, y) pairs."""
    path = "--".join(f" ({px},{py})" for px, py in points)
    return "\\draw[black!25,very thin] " + path + ";\n"


for layer_index, layer in enumerate(model):
    with open(os.path.join(args.result_dir, f"warp_{layer_index}.tex"), "w") as f:
        f.write(
            """\\addplot[
    scatter src=explicit symbolic,
    scatter/classes={0={blue}, 1={red}},
    scatter, mark=*, only marks, mark options={mark size=0.5},
]%
table[meta=label] {
x y label
"""
        )
        # Plot at most 512 samples; the bound is clamped so that a
        # smaller training set does not index out of range.
        nb_shown = min(512, input.size(0))
        f.writelines(
            f"{px} {py} {label}\n"
            for (px, py), label in zip(
                input[:nb_shown].tolist(), targets[:nb_shown].tolist()
            )
        )
        f.write("};\n")

        # Grid lines: one polyline per row of nodes, then one per column.
        g = grid.reshape(sg, sg, -1).tolist()
        f.writelines(_polyline(row) for row in g)
        f.writelines(_polyline(column) for column in zip(*g))

        # add the decision line

        if layer_index == len(model) - 1:
            # With o = W x + bias the two outputs of the last layer, the
            # boundary is o[0] - o[1] = 0, that is a . x + b = 0 with
            # a = u W and b = u . bias.
            u = torch.tensor([[1.0, -1.0]])
            phi = model[-1]
            a = (u @ phi.weight.detach()).squeeze()
            b = (u @ phi.bias.detach()).item()
            # Point of the boundary closest to the origin. It must
            # satisfy a . p = -b, hence the minus sign.
            p = -a * (b / a.dot(a).item())
            # Segment through p, directed along the normal to a.
            f.write(
                f"\\draw[black,thick] ({p[0]-a[1]},{p[1]+a[0]}) -- ({p[0]+a[1]},{p[1]-a[0]});"
            )

    # Push the samples and the grid through the current layer. No
    # gradient is needed any more, so autograd is disabled.
    with torch.no_grad():
        input, grid = layer(input), layer(grid)
165
166     input, grid = m(input), m(grid)
167
168 ######################################################################