[picoclvr.git] / qmlp.py
#!/usr/bin/env python

# @XREMOTE_HOST: elk.fleuret.org
# @XREMOTE_EXEC: python
# @XREMOTE_PRE: source ${HOME}/misc/venv/pytorch/bin/activate
# @XREMOTE_PRE: killall -u ${USER} -q -9 python || true
# @XREMOTE_PRE: ln -sf ${HOME}/data/pytorch ./data
# @XREMOTE_SEND: *.py *.sh

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math, sys

import torch, torchvision

from torch import nn
from torch.nn import functional as F

######################################################################

# Values are quantized to integer codes in {0, ..., nb_quantization_levels - 1}
# over a given [xmin, xmax] range, so that both 2d samples and MLP weights can
# be encoded as token sequences.

nb_quantization_levels = 101


def quantize(x, xmin, xmax):
    return (
        ((x - xmin) / (xmax - xmin) * nb_quantization_levels)
        .long()
        .clamp(min=0, max=nb_quantization_levels - 1)
    )


def dequantize(q, xmin, xmax):
    return q / nb_quantization_levels * (xmax - xmin) + xmin

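# A minimal round-trip sketch (illustrative values only):
#
#   x = torch.tensor([-1.0, 0.0, 0.99])
#   q = quantize(x, -1, 1)        # long codes in {0, ..., nb_quantization_levels - 1}
#   x_hat = dequantize(q, -1, 1)  # close to x, up to the quantization step
#
# Note that dequantize returns the lower edge of each quantization bin rather
# than its center, so the round trip has a small systematic offset.
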
######################################################################


def generate_sets_and_params(
    batch_nb_mlps,
    nb_samples,
    batch_size,
    nb_epochs,
    device=torch.device("cpu"),
    print_log=False,
    save_as_examples=False,
):
    # Create batch_nb_mlps independent binary classification problems in
    # [-1, 1]^2: a point gets label 1 if it falls inside the union of nb_rec
    # random axis-aligned rectangles, 0 otherwise. Problems are re-drawn until
    # the two classes are roughly balanced (mean label within 0.1 of 0.5).
    data_input = torch.zeros(batch_nb_mlps, 2 * nb_samples, 2, device=device)
    data_targets = torch.zeros(
        batch_nb_mlps, 2 * nb_samples, dtype=torch.int64, device=device
    )

    while (data_targets.float().mean(-1) - 0.5).abs().max() > 0.1:
        i = (data_targets.float().mean(-1) - 0.5).abs() > 0.1
        nb = i.sum()

        nb_rec = 8
        nb_values = 2  # more increases the min-max gap
        support = torch.rand(nb, nb_rec, 2, nb_values, device=device) * 2 - 1
        support = support.sort(-1).values
        support = support[:, :, :, torch.tensor([0, nb_values - 1])].view(nb, nb_rec, 4)

        x = torch.rand(nb, 2 * nb_samples, 2, device=device) * 2 - 1
        y = (
            (
                (x[:, None, :, 0] >= support[:, :, None, 0]).long()
                * (x[:, None, :, 0] <= support[:, :, None, 1]).long()
                * (x[:, None, :, 1] >= support[:, :, None, 2]).long()
                * (x[:, None, :, 1] <= support[:, :, None, 3]).long()
            )
            .max(dim=1)
            .values
        )

        data_input[i], data_targets[i] = x, y

    # The first nb_samples points of each problem form the train set, the rest
    # the test set.
    train_input, train_targets = (
        data_input[:, :nb_samples],
        data_targets[:, :nb_samples],
    )
    test_input, test_targets = data_input[:, nb_samples:], data_targets[:, nb_samples:]

    # Quantize the inputs and train on the de-quantized values, so the MLPs see
    # exactly what can be reconstructed from the token codes.
    q_train_input = quantize(train_input, -1, 1)
    train_input = dequantize(q_train_input, -1, 1)

    q_test_input = quantize(test_input, -1, 1)
    test_input = dequantize(q_test_input, -1, 1)

    if save_as_examples:
        for k in range(q_train_input.size(0)):
            with open(f"example_{k:04d}.dat", "w") as f:
                for u, c in zip(train_input[k], train_targets[k]):
                    f.write(f"{c} {u[0].item()} {u[1].item()}\n")

    # One 2-32-2 MLP per problem, with all batch_nb_mlps MLPs trained in
    # parallel through batched einsums.
    hidden_dim = 32
    w1 = torch.randn(batch_nb_mlps, hidden_dim, 2, device=device) / math.sqrt(2)
    b1 = torch.zeros(batch_nb_mlps, hidden_dim, device=device)
    w2 = torch.randn(batch_nb_mlps, 2, hidden_dim, device=device) / math.sqrt(
        hidden_dim
    )
    b2 = torch.zeros(batch_nb_mlps, 2, device=device)

    w1.requires_grad_()
    b1.requires_grad_()
    w2.requires_grad_()
    b2.requires_grad_()
    optimizer = torch.optim.Adam([w1, b1, w2, b2], lr=1e-2)

    criterion = nn.CrossEntropyLoss()
    criterion.to(device)

    for k in range(nb_epochs):
        acc_train_loss = 0.0
        nb_train_errors = 0

        for input, targets in zip(
            train_input.split(batch_size, dim=1), train_targets.split(batch_size, dim=1)
        ):
            h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
            h = F.relu(h)
            output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
            loss = F.cross_entropy(
                output.reshape(-1, output.size(-1)), targets.reshape(-1)
            )
            acc_train_loss += loss.item() * input.size(0)

            wta = output.argmax(-1)
            nb_train_errors += (wta != targets).long().sum(-1)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Snap a randomly chosen, growing fraction of the weights to their
        # quantized values, so that by the last epoch the MLPs run entirely on
        # quantized parameters.
        with torch.no_grad():
            for p in [w1, b1, w2, b2]:
                m = (
                    torch.rand(p.size(), device=p.device) <= k / (nb_epochs - 1)
                ).long()
                pq = quantize(p, -2, 2)
                p[...] = (1 - m) * p + m * dequantize(pq, -2, 2)

        train_error = nb_train_errors / train_input.size(1)
        acc_train_loss = acc_train_loss / train_input.size(1)

        # print(f"{k=} {acc_train_loss=} {train_error=}")

    # Evaluate the trained, now fully quantized, MLPs on the held-out half.
    acc_test_loss = 0
    nb_test_errors = 0

    for input, targets in zip(
        test_input.split(batch_size, dim=1), test_targets.split(batch_size, dim=1)
    ):
        h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
        h = F.relu(h)
        output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
        loss = F.cross_entropy(output.reshape(-1, output.size(-1)), targets.reshape(-1))
        acc_test_loss += loss.item() * input.size(0)

        wta = output.argmax(-1)
        nb_test_errors += (wta != targets).long().sum(-1)

    test_error = nb_test_errors / test_input.size(1)

    # Pack everything as token sequences: parameters are quantized over [-2, 2],
    # and each data point becomes three tokens (x code, y code, label).
    q_params = torch.cat(
        [quantize(p.view(batch_nb_mlps, -1), -2, 2) for p in [w1, b1, w2, b2]], dim=1
    )
    q_train_set = torch.cat([q_train_input, train_targets[:, :, None]], -1).reshape(
        batch_nb_mlps, -1
    )
    q_test_set = torch.cat([q_test_input, test_targets[:, :, None]], -1).reshape(
        batch_nb_mlps, -1
    )

    return q_train_set, q_test_set, q_params, test_error

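# A minimal usage sketch (illustrative values; shapes follow from the code above):
#
#   q_train_set, q_test_set, q_params, test_error = generate_sets_and_params(
#       batch_nb_mlps=4, nb_samples=250, batch_size=25, nb_epochs=10,
#   )
#   # q_train_set, q_test_set: (4, 250 * 3) token codes, three per sample
#   # q_params: (4, 162) quantized weights (64 + 32 + 64 + 2 for a 2-32-2 MLP)
#   # test_error: (4,) per-MLP error rate on the held-out samples
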
######################################################################


def evaluate_q_params(
    q_params,
    q_set,
    batch_size=25,
    device=torch.device("cpu"),
    nb_mlps_per_batch=1024,
    save_as_examples=False,
):
    errors = []
    nb_mlps = q_params.size(0)

    for n in range(0, nb_mlps, nb_mlps_per_batch):
        batch_nb_mlps = min(nb_mlps_per_batch, nb_mlps - n)
        batch_q_params = q_params[n : n + batch_nb_mlps]
        batch_q_set = q_set[n : n + batch_nb_mlps]
        hidden_dim = 32
        w1 = torch.empty(batch_nb_mlps, hidden_dim, 2, device=device)
        b1 = torch.empty(batch_nb_mlps, hidden_dim, device=device)
        w2 = torch.empty(batch_nb_mlps, 2, hidden_dim, device=device)
        b2 = torch.empty(batch_nb_mlps, 2, device=device)

        # Unpack the flat parameter codes back into weight and bias tensors, in
        # the same order they were packed by generate_sets_and_params.
        with torch.no_grad():
            k = 0
            for p in [w1, b1, w2, b2]:
                print(f"{p.size()=}")
                x = dequantize(
                    batch_q_params[:, k : k + p.numel() // batch_nb_mlps], -2, 2
                ).view(p.size())
                p.copy_(x)
                k += p.numel() // batch_nb_mlps

        # Each sample is three tokens: quantized x, quantized y, label.
        batch_q_set = batch_q_set.view(batch_nb_mlps, -1, 3)
        data_input = dequantize(batch_q_set[:, :, :2], -1, 1).to(device)
        data_targets = batch_q_set[:, :, 2].to(device)

        print(f"{data_input.size()=} {data_targets.size()=}")

        criterion = nn.CrossEntropyLoss()
        criterion.to(device)

        acc_loss = 0.0
        nb_errors = 0

        for input, targets in zip(
            data_input.split(batch_size, dim=1), data_targets.split(batch_size, dim=1)
        ):
            h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
            h = F.relu(h)
            output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
            loss = F.cross_entropy(
                output.reshape(-1, output.size(-1)), targets.reshape(-1)
            )
            acc_loss += loss.item() * input.size(0)
            wta = output.argmax(-1)
            nb_errors += (wta != targets).long().sum(-1)

        errors.append(nb_errors / data_input.size(1))
        acc_loss = acc_loss / data_input.size(1)

    return torch.cat(errors)

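# Usage sketch (hedged; mirrors the sanity check in the __main__ block below):
#
#   errors = evaluate_q_params(q_params, q_train_set)  # (nb_mlps,) error rates
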
######################################################################


def generate_sequence_and_test_set(
    nb_mlps,
    nb_samples,
    batch_size,
    nb_epochs,
    device,
    nb_mlps_per_batch=1024,
):
    # Build the full token sequences in chunks of at most nb_mlps_per_batch
    # MLPs: each sequence is a quantized train set, a separator token, and the
    # quantized parameters of the MLP trained on that set.
    seqs, q_test_sets, test_errors = [], [], []

    for n in range(0, nb_mlps, nb_mlps_per_batch):
        q_train_set, q_test_set, q_params, test_error = generate_sets_and_params(
            batch_nb_mlps=min(nb_mlps_per_batch, nb_mlps - n),
            nb_samples=nb_samples,
            batch_size=batch_size,
            nb_epochs=nb_epochs,
            device=device,
        )

        seqs.append(
            torch.cat(
                [
                    q_train_set,
                    q_train_set.new_full(
                        (
                            q_train_set.size(0),
                            1,
                        ),
                        nb_quantization_levels,
                    ),
                    q_params,
                ],
                dim=-1,
            )
        )

        q_test_sets.append(q_test_set)
        test_errors.append(test_error)

    seq = torch.cat(seqs)
    q_test_set = torch.cat(q_test_sets)
    test_error = torch.cat(test_errors)

    return seq, q_test_set, test_error

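# Layout of one row of `seq`, as sliced back apart in the __main__ block below:
#
#   [ q_train_set: nb_samples * 3 tokens | separator: nb_quantization_levels | q_params ]
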
######################################################################

if __name__ == "__main__":
    import time

    batch_nb_mlps, nb_samples = 128, 2500

    # Quick check: generate a few problems on CPU and dump them as example files.
    generate_sets_and_params(
        batch_nb_mlps=10,
        nb_samples=nb_samples,
        batch_size=25,
        nb_epochs=100,
        device=torch.device("cpu"),
        print_log=False,
        save_as_examples=True,
    )

    exit(0)

    # The code below is skipped by the exit(0) above; it benchmarks sequence
    # generation and sanity-checks the packed parameters.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    start_time = time.perf_counter()

    data = []

    seq, q_test_set, test_error = generate_sequence_and_test_set(
        nb_mlps=batch_nb_mlps,
        nb_samples=nb_samples,
        device=device,
        batch_size=25,
        nb_epochs=250,
        nb_mlps_per_batch=17,
    )

    end_time = time.perf_counter()
    print(f"{seq.size(0) / (end_time - start_time):.02f} samples per second")

    # Recover the train-set tokens and the parameter tokens from the sequences
    # (skipping the separator token) and re-evaluate the reconstructed MLPs.
    q_train_set = seq[:, : nb_samples * 3]
    q_params = seq[:, nb_samples * 3 + 1 :]
    print(f"SANITY #2 {q_train_set.size()=} {q_params.size()=} {seq.size()=}")
    error_train = evaluate_q_params(q_params, q_train_set, nb_mlps_per_batch=17)
    print(f"train {error_train*100}%")
    error_test = evaluate_q_params(q_params, q_test_set, nb_mlps_per_batch=17)
    print(f"test {error_test*100}%")