Update.
[picoclvr.git] / qmlp.py
1 #!/usr/bin/env python
2
3 # @XREMOTE_HOST: elk.fleuret.org
4 # @XREMOTE_EXEC: python
5 # @XREMOTE_PRE: source ${HOME}/misc/venv/pytorch/bin/activate
6 # @XREMOTE_PRE: killall -u ${USER} -q -9 python || true
7 # @XREMOTE_PRE: ln -sf ${HOME}/data/pytorch ./data
8 # @XREMOTE_SEND: *.py *.sh
9
10 # Any copyright is dedicated to the Public Domain.
11 # https://creativecommons.org/publicdomain/zero/1.0/
12
13 # Written by Francois Fleuret <francois@fleuret.org>
14
15 import math, sys
16
17 import torch, torchvision
18
19 from torch import nn
20 from torch.nn import functional as F
21
22 ######################################################################
23
24 nb_quantization_levels = 101
25
26
def quantize(x, xmin, xmax):
    """Convert real-valued tensor entries into integer bin indices.

    The interval [xmin, xmax] is rescaled linearly onto
    [0, nb_quantization_levels], the result is cast to int64 (which
    truncates), and the indices are clamped to
    0 .. nb_quantization_levels - 1 so that out-of-range inputs fall in
    the first or last bin.
    """
    scaled = (x - xmin) / (xmax - xmin) * nb_quantization_levels
    bins = scaled.long()
    return bins.clamp(min=0, max=nb_quantization_levels - 1)
33
34
def dequantize(q, xmin, xmax):
    """Map bin indices back to real values.

    Bin ``q`` is sent to ``xmin + q / nb_quantization_levels * (xmax - xmin)``.
    """
    span = xmax - xmin
    fraction = q / nb_quantization_levels
    return fraction * span + xmin
37
38
39 ######################################################################
40
41
def _rectangle_labels(points, support):
    """Label each point 1 if it lies inside at least one rectangle, else 0.

    points is indexed as (mlp, sample, coordinate) and support as
    (mlp, rectangle, 4), the last axis holding the bounds in the order
    (min of coord 0, max of coord 0, min of coord 1, max of coord 1).
    """
    p0, p1 = points[:, None, :, 0], points[:, None, :, 1]
    inside = (
        (p0 >= support[:, :, None, 0]).long()
        * (p0 <= support[:, :, None, 1]).long()
        * (p1 >= support[:, :, None, 2]).long()
        * (p1 <= support[:, :, None, 3]).long()
    )
    # A point is positive as soon as one rectangle of its MLP contains it.
    return inside.max(dim=1).values


def _batched_mlp_forward(w1, b1, w2, b2, x):
    """Apply one-hidden-layer ReLU MLPs, one per leading index, to x."""
    h = F.relu(torch.einsum("mij,mnj->mni", w1, x) + b1[:, None, :])
    return torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]


def _write_example_files(stem, inputs, targets, nb_examples):
    """Write example_<stem>_<k>.dat files with one `label x0 x1` line per sample."""
    for k in range(nb_examples):
        with open(f"example_{stem}_{k:04d}.dat", "w") as f:
            f.writelines(
                f"{c} {u0} {u1}\n"
                for (u0, u1), c in zip(inputs[k].tolist(), targets[k].tolist())
            )


def generate_sets_and_params(
    batch_nb_mlps,
    nb_samples,
    batch_size,
    nb_epochs,
    device=torch.device("cpu"),
    print_log=False,
    save_as_examples=False,
):
    """Generate 2d classification problems and train one small MLP on each.

    Each problem labels points of [-1, 1]^2 by membership in a union of
    random axis-aligned rectangles. Problems are resampled until the
    fraction of positive labels is within 0.1 of 0.5. Inputs are quantized,
    and the MLP parameters are progressively quantized during training.

    Args:
        batch_nb_mlps: number of problems / MLPs processed in parallel.
        nb_samples: number of train samples (and of test samples) per problem.
        batch_size: number of samples per mini-batch along the sample axis.
        nb_epochs: number of training epochs.
        device: device on which all tensors are created.
        print_log: if true, print the training loss and the per-MLP train
            error rate after every epoch.
        save_as_examples: if true, write example_full_*.dat and
            example_train_*.dat files for at most the first 10 problems.

    Returns:
        A tuple (q_train_set, q_test_set, q_params, test_error):
        q_train_set and q_test_set are int64 tensors with one row per MLP
        holding flattened (q_x0, q_x1, label) triplets, q_params holds the
        flattened quantized w1, b1, w2, b2 of each MLP, and test_error is
        the per-MLP test error rate.
    """
    data_input = torch.zeros(batch_nb_mlps, 2 * nb_samples, 2, device=device)
    data_targets = torch.zeros(
        batch_nb_mlps, 2 * nb_samples, dtype=torch.int64, device=device
    )

    nb_rec = 8
    nb_values = 2  # more increases the min-max gap

    rec_support = torch.empty(batch_nb_mlps, nb_rec, 4, device=device)
    bound_index = torch.tensor([0, nb_values - 1], device=device)

    def unbalanced():
        # Problems whose positive rate is further than 0.1 from 0.5.
        return (data_targets.float().mean(-1) - 0.5).abs() > 0.1

    # The targets start at zero, so every problem is sampled at least once.
    i = unbalanced()
    while i.any():
        nb = int(i.sum())
        support = torch.rand(nb, nb_rec, 2, nb_values, device=device) * 2 - 1
        support = support.sort(-1).values
        # Keep the smallest and largest value per coordinate as the bounds.
        support = support[:, :, :, bound_index].view(nb, nb_rec, 4)

        x = torch.rand(nb, 2 * nb_samples, 2, device=device) * 2 - 1
        y = _rectangle_labels(x, support)

        data_input[i], data_targets[i], rec_support[i] = x, y, support
        i = unbalanced()

    train_input, train_targets = (
        data_input[:, :nb_samples],
        data_targets[:, :nb_samples],
    )
    test_input, test_targets = data_input[:, nb_samples:], data_targets[:, nb_samples:]

    # The MLPs are trained on the dequantized values of the quantized inputs.
    q_train_input = quantize(train_input, -1, 1)
    train_input = dequantize(q_train_input, -1, 1)

    q_test_input = quantize(test_input, -1, 1)
    test_input = dequantize(q_test_input, -1, 1)

    if save_as_examples:
        nb_examples = min(batch_nb_mlps, 10)

        # Regular grid over [-1, 1]^2, created on the same device as
        # rec_support so the comparisons below do not mix devices.
        a = (
            2
            * torch.arange(nb_quantization_levels, device=device).float()
            / (nb_quantization_levels - 1)
            - 1
        )
        xf = torch.cat(
            [
                a[:, None, None].expand(
                    nb_quantization_levels, nb_quantization_levels, 1
                ),
                a[None, :, None].expand(
                    nb_quantization_levels, nb_quantization_levels, 1
                ),
            ],
            2,
        )
        xf = xf.reshape(1, -1, 2).expand(nb_examples, -1, -1)
        yf = _rectangle_labels(xf, rec_support[:nb_examples])

        full_input = dequantize(quantize(xf, -1, 1), -1, 1)

        _write_example_files("full", full_input, yf, nb_examples)
        _write_example_files("train", train_input, train_targets, nb_examples)

    hidden_dim = 32
    w1 = torch.randn(batch_nb_mlps, hidden_dim, 2, device=device) / math.sqrt(2)
    b1 = torch.zeros(batch_nb_mlps, hidden_dim, device=device)
    w2 = torch.randn(batch_nb_mlps, 2, hidden_dim, device=device) / math.sqrt(
        hidden_dim
    )
    b2 = torch.zeros(batch_nb_mlps, 2, device=device)

    params = [w1, b1, w2, b2]
    for p in params:
        p.requires_grad_()
    optimizer = torch.optim.Adam(params, lr=1e-2)

    for k in range(nb_epochs):
        acc_train_loss = 0.0
        nb_train_errors = 0

        for batch_input, batch_targets in zip(
            train_input.split(batch_size, dim=1), train_targets.split(batch_size, dim=1)
        ):
            output = _batched_mlp_forward(w1, b1, w2, b2, batch_input)
            loss = F.cross_entropy(
                output.reshape(-1, output.size(-1)), batch_targets.reshape(-1)
            )

            if print_log:
                # The loss is a mean over the mini-batch: weight it by the
                # number of samples so the epoch value is a per-sample mean.
                acc_train_loss += loss.item() * batch_input.size(1)
                wta = output.argmax(-1)
                nb_train_errors += (wta != batch_targets).long().sum(-1)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Each parameter is replaced by its quantized value with a
        # probability growing linearly from 0 at the first epoch to 1 at the
        # last one. With a single epoch that epoch is also the last one.
        quantize_prob = k / (nb_epochs - 1) if nb_epochs > 1 else 1.0

        with torch.no_grad():
            for p in params:
                m = (torch.rand(p.size(), device=p.device) <= quantize_prob).long()
                pq = quantize(p, -2, 2)
                p[...] = (1 - m) * p + m * dequantize(pq, -2, 2)

        if print_log:
            train_error = nb_train_errors / train_input.size(1)
            acc_train_loss = acc_train_loss / train_input.size(1)
            print(f"{k=} {acc_train_loss=} {train_error=}")

    nb_test_errors = torch.zeros(batch_nb_mlps, dtype=torch.int64, device=device)

    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for batch_input, batch_targets in zip(
            test_input.split(batch_size, dim=1), test_targets.split(batch_size, dim=1)
        ):
            output = _batched_mlp_forward(w1, b1, w2, b2, batch_input)
            wta = output.argmax(-1)
            nb_test_errors += (wta != batch_targets).long().sum(-1)

    test_error = nb_test_errors / test_input.size(1)
    q_params = torch.cat(
        [quantize(p.view(batch_nb_mlps, -1), -2, 2) for p in params], dim=1
    )
    q_train_set = torch.cat([q_train_input, train_targets[:, :, None]], -1).reshape(
        batch_nb_mlps, -1
    )
    q_test_set = torch.cat([q_test_input, test_targets[:, :, None]], -1).reshape(
        batch_nb_mlps, -1
    )

    return q_train_set, q_test_set, q_params, test_error
205
206
207 ######################################################################
208
209
def evaluate_q_params(
    q_params,
    q_set,
    batch_size=25,
    device=torch.device("cpu"),
    nb_mlps_per_batch=1024,
    save_as_examples=False,
):
    """Compute the error rate of quantized MLPs on quantized data sets.

    Args:
        q_params: one row per MLP holding the flattened quantized w1, b1,
            w2, b2, in that order, for a hidden layer of size 32.
        q_set: one row per MLP holding flattened (q_x0, q_x1, label)
            triplets.
        batch_size: number of samples per mini-batch along the sample axis.
        device: device on which the evaluation is run.
        nb_mlps_per_batch: maximum number of MLPs evaluated at once.
        save_as_examples: accepted for interface compatibility; not used.

    Returns:
        A 1d tensor with the error rate of each MLP on its own data set
        (empty if q_params has no row).
    """
    hidden_dim = 32
    # Per-MLP shapes of w1, b1, w2, b2, in their storage order in q_params.
    param_shapes = [(hidden_dim, 2), (hidden_dim,), (2, hidden_dim), (2,)]

    errors = []
    nb_mlps = q_params.size(0)

    with torch.no_grad():
        for n in range(0, nb_mlps, nb_mlps_per_batch):
            batch_nb_mlps = min(nb_mlps_per_batch, nb_mlps - n)
            batch_q_params = q_params[n : n + batch_nb_mlps]
            batch_q_set = q_set[n : n + batch_nb_mlps]

            # Cut the flat parameter rows into the four tensors.
            params = []
            k = 0
            for shape in param_shapes:
                nb = math.prod(shape)
                values = dequantize(batch_q_params[:, k : k + nb], -2, 2)
                params.append(values.reshape(batch_nb_mlps, *shape).to(device))
                k += nb
            w1, b1, w2, b2 = params

            batch_q_set = batch_q_set.reshape(batch_nb_mlps, -1, 3)
            data_input = dequantize(batch_q_set[:, :, :2], -1, 1).to(device)
            data_targets = batch_q_set[:, :, 2].to(device)

            nb_errors = torch.zeros(batch_nb_mlps, dtype=torch.int64, device=device)

            for batch_input, batch_targets in zip(
                data_input.split(batch_size, dim=1),
                data_targets.split(batch_size, dim=1),
            ):
                h = torch.einsum("mij,mnj->mni", w1, batch_input) + b1[:, None, :]
                h = F.relu(h)
                output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
                wta = output.argmax(-1)
                nb_errors += (wta != batch_targets).long().sum(-1)

            errors.append(nb_errors / data_input.size(1))

    if not errors:
        # torch.cat rejects an empty list; return an empty result instead.
        return torch.empty(0, device=device)

    return torch.cat(errors)
266
267
268 ######################################################################
269
270
def generate_sequence_and_test_set(
    nb_mlps,
    nb_samples,
    batch_size,
    nb_epochs,
    device,
    nb_mlps_per_batch=1024,
):
    """Build token sequences `train set | separator | params` for nb_mlps MLPs.

    The MLPs are generated in chunks of at most nb_mlps_per_batch by
    generate_sets_and_params. For every MLP the quantized train set, one
    separator token of value nb_quantization_levels, and the quantized
    parameters are concatenated into a single row.

    Returns:
        A tuple (seq, q_test_set, test_error) obtained by concatenating the
        chunks along the first dimension.
    """
    seq_chunks = []
    test_set_chunks = []
    test_error_chunks = []

    for start in range(0, nb_mlps, nb_mlps_per_batch):
        chunk_size = min(nb_mlps_per_batch, nb_mlps - start)

        q_train_set, q_test_set, q_params, test_error = generate_sets_and_params(
            batch_nb_mlps=chunk_size,
            nb_samples=nb_samples,
            batch_size=batch_size,
            nb_epochs=nb_epochs,
            device=device,
        )

        # One-column separator between the train set and the parameters; its
        # value is outside the range produced by quantize.
        separator = q_train_set.new_full(
            (q_train_set.size(0), 1), nb_quantization_levels
        )

        seq_chunks.append(torch.cat([q_train_set, separator, q_params], dim=-1))
        test_set_chunks.append(q_test_set)
        test_error_chunks.append(test_error)

    return (
        torch.cat(seq_chunks),
        torch.cat(test_set_chunks),
        torch.cat(test_error_chunks),
    )
315
316
317 ######################################################################
318
if __name__ == "__main__":
    import time

    batch_nb_mlps, nb_samples = 128, 250

    # Train 10 MLPs on CPU and dump them as example_*.dat files.
    generate_sets_and_params(
        batch_nb_mlps=10,
        nb_samples=nb_samples,
        batch_size=25,
        nb_epochs=100,
        device=torch.device("cpu"),
        print_log=False,
        save_as_examples=True,
    )

    # The script stops here: the benchmark and sanity check below only run
    # if this early exit is removed. sys.exit is used instead of the `exit`
    # helper, which is only injected by the site module.
    sys.exit(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    start_time = time.perf_counter()

    seq, q_test_set, test_error = generate_sequence_and_test_set(
        nb_mlps=batch_nb_mlps,
        nb_samples=nb_samples,
        device=device,
        batch_size=25,
        nb_epochs=250,
        nb_mlps_per_batch=17,
    )

    end_time = time.perf_counter()
    print(f"{seq.size(0) / (end_time - start_time):.02f} samples per second")

    # Each row of seq is: nb_samples (x0, x1, label) triplets, one separator
    # token, then the quantized parameters.
    q_train_set = seq[:, : nb_samples * 3]
    q_params = seq[:, nb_samples * 3 + 1 :]
    print(f"SANITY #2 {q_train_set.size()=} {q_params.size()=} {seq.size()=}")
    error_train = evaluate_q_params(q_params, q_train_set, nb_mlps_per_batch=17)
    print(f"train {error_train*100}%")
    error_test = evaluate_q_params(q_params, q_test_set, nb_mlps_per_batch=17)
    print(f"test {error_test*100}%")