#!/usr/bin/env python

# @XREMOTE_HOST: elk.fleuret.org
# @XREMOTE_EXEC: python
# @XREMOTE_PRE: source ${HOME}/misc/venv/pytorch/bin/activate
# @XREMOTE_PRE: killall -u ${USER} -q -9 python || true
# @XREMOTE_PRE: ln -sf ${HOME}/data/pytorch ./data
# @XREMOTE_SEND: *.py *.sh

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math, sys

import torch, torchvision

from torch import nn
from torch.nn import functional as F

######################################################################

nb_quantization_levels = 101


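# Scalar quantization helpers: map values from [xmin, xmax] to integer levels
# 0 .. nb_quantization_levels - 1 (values outside the range saturate), and map
# a level back to the lower edge of its bin.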
def quantize(x, xmin, xmax):
    return (
        ((x - xmin) / (xmax - xmin) * nb_quantization_levels)
        .long()
        .clamp(min=0, max=nb_quantization_levels - 1)
    )


def dequantize(q, xmin, xmax):
    return q / nb_quantization_levels * (xmax - xmin) + xmin


######################################################################


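# Generate a batch of 2d binary-classification tasks (for each task, the
# positive class is the union of nb_rec random axis-aligned rectangles in
# [-1, 1]^2), train one small MLP per task with progressively quantized
# parameters, and return the quantized train set, quantized test set,
# quantized parameters, and per-task test error.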
def generate_sets_and_params(
    batch_nb_mlps,
    nb_samples,
    batch_size,
    nb_epochs,
    device=torch.device("cpu"),
    print_log=False,
    save_as_examples=False,
):
    data_input = torch.zeros(batch_nb_mlps, 2 * nb_samples, 2, device=device)
    data_targets = torch.zeros(
        batch_nb_mlps, 2 * nb_samples, dtype=torch.int64, device=device
    )

    nb_rec = 8
    nb_values = 2  # more increases the min-max gap

    rec_support = torch.empty(batch_nb_mlps, nb_rec, 4, device=device)

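    # Rejection sampling: re-draw the rectangles and the points of every task
    # whose empirical class balance deviates from 50% by more than 0.1, so all
    # the tasks end up roughly balanced. A point is labeled 1 if it falls
    # inside at least one rectangle.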
    while (data_targets.float().mean(-1) - 0.5).abs().max() > 0.1:
        i = (data_targets.float().mean(-1) - 0.5).abs() > 0.1
        nb = i.sum()
        support = torch.rand(nb, nb_rec, 2, nb_values, device=device) * 2 - 1
        support = support.sort(-1).values
        support = support[:, :, :, torch.tensor([0, nb_values - 1])].view(nb, nb_rec, 4)

        x = torch.rand(nb, 2 * nb_samples, 2, device=device) * 2 - 1
        y = (
            (
                (x[:, None, :, 0] >= support[:, :, None, 0]).long()
                * (x[:, None, :, 0] <= support[:, :, None, 1]).long()
                * (x[:, None, :, 1] >= support[:, :, None, 2]).long()
                * (x[:, None, :, 1] <= support[:, :, None, 3]).long()
            )
            .max(dim=1)
            .values
        )

        data_input[i], data_targets[i], rec_support[i] = x, y, support

    train_input, train_targets = (
        data_input[:, :nb_samples],
        data_targets[:, :nb_samples],
    )
    test_input, test_targets = data_input[:, nb_samples:], data_targets[:, nb_samples:]

    q_train_input = quantize(train_input, -1, 1)
    train_input = dequantize(q_train_input, -1, 1)

    q_test_input = quantize(test_input, -1, 1)
    test_input = dequantize(q_test_input, -1, 1)

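    # Optionally dump, for up to ten of the tasks, the labels over the full
    # quantization grid and the training points themselves, as plain-text
    # .dat files with one "label x0 x1" line per point.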
    if save_as_examples:
        a = (
            2
            * torch.arange(nb_quantization_levels).float()
            / (nb_quantization_levels - 1)
            - 1
        )
        xf = torch.cat(
            [
                a[:, None, None].expand(
                    nb_quantization_levels, nb_quantization_levels, 1
                ),
                a[None, :, None].expand(
                    nb_quantization_levels, nb_quantization_levels, 1
                ),
            ],
            2,
        )
        xf = xf.reshape(1, -1, 2).expand(min(q_train_input.size(0), 10), -1, -1)
        print(f"{xf.size()=} {x.size()=}")
        yf = (
            (
                (xf[:, None, :, 0] >= rec_support[: xf.size(0), :, None, 0]).long()
                * (xf[:, None, :, 0] <= rec_support[: xf.size(0), :, None, 1]).long()
                * (xf[:, None, :, 1] >= rec_support[: xf.size(0), :, None, 2]).long()
                * (xf[:, None, :, 1] <= rec_support[: xf.size(0), :, None, 3]).long()
            )
            .max(dim=1)
            .values
        )

        full_input, full_targets = xf, yf

        q_full_input = quantize(full_input, -1, 1)
        full_input = dequantize(q_full_input, -1, 1)

        for k in range(q_full_input[:10].size(0)):
            with open(f"example_full_{k:04d}.dat", "w") as f:
                for u, c in zip(full_input[k], full_targets[k]):
                    f.write(f"{c} {u[0].item()} {u[1].item()}\n")

        for k in range(q_train_input[:10].size(0)):
            with open(f"example_train_{k:04d}.dat", "w") as f:
                for u, c in zip(train_input[k], train_targets[k]):
                    f.write(f"{c} {u[0].item()} {u[1].item()}\n")

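    # One-hidden-layer MLP (2 -> hidden_dim -> 2) per task, with all the
    # weights held in batched tensors and a 1/sqrt(fan-in) Gaussian init.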
    hidden_dim = 32
    w1 = torch.randn(batch_nb_mlps, hidden_dim, 2, device=device) / math.sqrt(2)
    b1 = torch.zeros(batch_nb_mlps, hidden_dim, device=device)
    w2 = torch.randn(batch_nb_mlps, 2, hidden_dim, device=device) / math.sqrt(
        hidden_dim
    )
    b2 = torch.zeros(batch_nb_mlps, 2, device=device)

    w1.requires_grad_()
    b1.requires_grad_()
    w2.requires_grad_()
    b2.requires_grad_()
    optimizer = torch.optim.Adam([w1, b1, w2, b2], lr=1e-2)

    criterion = nn.CrossEntropyLoss()
    criterion.to(device)

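    # Train all the MLPs jointly: the einsum "mij,mnj->mni" applies, for every
    # task m, its own weight matrix to its own mini-batch of points.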
    for k in range(nb_epochs):
        acc_train_loss = 0.0
        nb_train_errors = 0

        for input, targets in zip(
            train_input.split(batch_size, dim=1), train_targets.split(batch_size, dim=1)
        ):
            h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
            h = F.relu(h)
            output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
            loss = F.cross_entropy(
                output.reshape(-1, output.size(-1)), targets.reshape(-1)
            )
            acc_train_loss += loss.item() * input.size(0)

            wta = output.argmax(-1)
            nb_train_errors += (wta != targets).long().sum(-1)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

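        # Progressive quantization: after epoch k, each parameter entry is
        # snapped to its dequantized grid value with probability
        # k / (nb_epochs - 1), so the fraction of quantized weights grows from
        # 0 to 1 and the final parameters all lie on the [-2, 2] grid.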
        with torch.no_grad():
            for p in [w1, b1, w2, b2]:
                m = (
                    torch.rand(p.size(), device=p.device) <= k / (nb_epochs - 1)
                ).long()
                pq = quantize(p, -2, 2)
                p[...] = (1 - m) * p + m * dequantize(pq, -2, 2)

        train_error = nb_train_errors / train_input.size(1)
        acc_train_loss = acc_train_loss / train_input.size(1)

        # print(f"{k=} {acc_train_loss=} {train_error=}")

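    # Evaluate the now fully quantized MLPs on the held-out half of the samples.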
    acc_test_loss = 0
    nb_test_errors = 0

    for input, targets in zip(
        test_input.split(batch_size, dim=1), test_targets.split(batch_size, dim=1)
    ):
        h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
        h = F.relu(h)
        output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
        loss = F.cross_entropy(output.reshape(-1, output.size(-1)), targets.reshape(-1))
        acc_test_loss += loss.item() * input.size(0)

        wta = output.argmax(-1)
        nb_test_errors += (wta != targets).long().sum(-1)

    test_error = nb_test_errors / test_input.size(1)
    q_params = torch.cat(
        [quantize(p.view(batch_nb_mlps, -1), -2, 2) for p in [w1, b1, w2, b2]], dim=1
    )
    q_train_set = torch.cat([q_train_input, train_targets[:, :, None]], -1).reshape(
        batch_nb_mlps, -1
    )
    q_test_set = torch.cat([q_test_input, test_targets[:, :, None]], -1).reshape(
        batch_nb_mlps, -1
    )

    return q_train_set, q_test_set, q_params, test_error


######################################################################


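# Re-instantiate the batched MLPs from a flat tensor of quantized parameters,
# dequantize a quantized data set, and return the per-task error rate. The
# MLPs are processed in chunks of nb_mlps_per_batch to bound memory use.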
def evaluate_q_params(
    q_params,
    q_set,
    batch_size=25,
    device=torch.device("cpu"),
    nb_mlps_per_batch=1024,
    save_as_examples=False,
):
    errors = []
    nb_mlps = q_params.size(0)

    for n in range(0, nb_mlps, nb_mlps_per_batch):
        batch_nb_mlps = min(nb_mlps_per_batch, nb_mlps - n)
        batch_q_params = q_params[n : n + batch_nb_mlps]
        batch_q_set = q_set[n : n + batch_nb_mlps]
        hidden_dim = 32
        w1 = torch.empty(batch_nb_mlps, hidden_dim, 2, device=device)
        b1 = torch.empty(batch_nb_mlps, hidden_dim, device=device)
        w2 = torch.empty(batch_nb_mlps, 2, hidden_dim, device=device)
        b2 = torch.empty(batch_nb_mlps, 2, device=device)

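        # Fill the weight tensors by dequantizing consecutive slices of the
        # flat parameter vector, in the same order as they were concatenated
        # by generate_sets_and_params (w1, b1, w2, b2).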
        with torch.no_grad():
            k = 0
            for p in [w1, b1, w2, b2]:
                print(f"{p.size()=}")
                x = dequantize(
                    batch_q_params[:, k : k + p.numel() // batch_nb_mlps], -2, 2
                ).view(p.size())
                p.copy_(x)
                k += p.numel() // batch_nb_mlps

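        # The quantized set is a flat sequence of (quantized x0, quantized x1,
        # label) triplets, one triplet per sample.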
        batch_q_set = batch_q_set.view(batch_nb_mlps, -1, 3)
        data_input = dequantize(batch_q_set[:, :, :2], -1, 1).to(device)
        data_targets = batch_q_set[:, :, 2].to(device)

        print(f"{data_input.size()=} {data_targets.size()=}")

        criterion = nn.CrossEntropyLoss()
        criterion.to(device)

        acc_loss = 0.0
        nb_errors = 0

        for input, targets in zip(
            data_input.split(batch_size, dim=1), data_targets.split(batch_size, dim=1)
        ):
            h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
            h = F.relu(h)
            output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
            loss = F.cross_entropy(
                output.reshape(-1, output.size(-1)), targets.reshape(-1)
            )
            acc_loss += loss.item() * input.size(0)
            wta = output.argmax(-1)
            nb_errors += (wta != targets).long().sum(-1)

        errors.append(nb_errors / data_input.size(1))
        acc_loss = acc_loss / data_input.size(1)

    return torch.cat(errors)


######################################################################


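# Build, in chunks of nb_mlps_per_batch, one token sequence per MLP, together
# with the matching quantized test set and the test error of the quantized
# MLP that generated the sequence.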
def generate_sequence_and_test_set(
    nb_mlps,
    nb_samples,
    batch_size,
    nb_epochs,
    device,
    nb_mlps_per_batch=1024,
):
    seqs, q_test_sets, test_errors = [], [], []

    for n in range(0, nb_mlps, nb_mlps_per_batch):
        q_train_set, q_test_set, q_params, test_error = generate_sets_and_params(
            batch_nb_mlps=min(nb_mlps_per_batch, nb_mlps - n),
            nb_samples=nb_samples,
            batch_size=batch_size,
            nb_epochs=nb_epochs,
            device=device,
        )

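        # Sequence layout: the quantized train-set tokens, one separator token
        # whose value nb_quantization_levels lies outside the
        # 0 .. nb_quantization_levels - 1 range of data and parameter tokens,
        # then the quantized parameter tokens.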
        seqs.append(
            torch.cat(
                [
                    q_train_set,
                    q_train_set.new_full(
                        (
                            q_train_set.size(0),
                            1,
                        ),
                        nb_quantization_levels,
                    ),
                    q_params,
                ],
                dim=-1,
            )
        )

        q_test_sets.append(q_test_set)
        test_errors.append(test_error)

    seq = torch.cat(seqs)
    q_test_set = torch.cat(q_test_sets)
    test_error = torch.cat(test_errors)

    return seq, q_test_set, test_error


######################################################################

if __name__ == "__main__":
    import time

    batch_nb_mlps, nb_samples = 128, 250

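    # Quick check: generate a few tasks on CPU and dump them as example .dat
    # files. The exit(0) below skips the full generation / evaluation pipeline
    # that follows.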
    generate_sets_and_params(
        batch_nb_mlps=10,
        nb_samples=nb_samples,
        batch_size=25,
        nb_epochs=100,
        device=torch.device("cpu"),
        print_log=False,
        save_as_examples=True,
    )

    exit(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    start_time = time.perf_counter()

    data = []

    seq, q_test_set, test_error = generate_sequence_and_test_set(
        nb_mlps=batch_nb_mlps,
        nb_samples=nb_samples,
        device=device,
        batch_size=25,
        nb_epochs=250,
        nb_mlps_per_batch=17,
    )

    end_time = time.perf_counter()
    print(f"{seq.size(0) / (end_time - start_time):.02f} samples per second")

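    # Each training sample contributes three tokens (two quantized coordinates
    # and a label); the separator token sits at index nb_samples * 3, followed
    # by the quantized parameters.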
    q_train_set = seq[:, : nb_samples * 3]
    q_params = seq[:, nb_samples * 3 + 1 :]
    print(f"SANITY #2 {q_train_set.size()=} {q_params.size()=} {seq.size()=}")
    error_train = evaluate_q_params(q_params, q_train_set, nb_mlps_per_batch=17)
    print(f"train {error_train*100}%")
    error_test = evaluate_q_params(q_params, q_test_set, nb_mlps_per_batch=17)
    print(f"test {error_test*100}%")