#!/usr/bin/env python

# @XREMOTE_HOST: elk.fleuret.org
# @XREMOTE_EXEC: python
# @XREMOTE_PRE: source ${HOME}/misc/venv/pytorch/bin/activate
# @XREMOTE_PRE: killall -u ${USER} -q -9 python || true
# @XREMOTE_PRE: ln -sf ${HOME}/data/pytorch ./data
# @XREMOTE_SEND: *.py *.sh

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math, sys

import torch, torchvision

from torch import nn
from torch.nn import functional as F

######################################################################

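# Inputs and MLP parameters are tokenized with a uniform quantizer over a fixed
# range, using nb_quantization_levels integer levels. The value
# nb_quantization_levels itself is reserved as a separator token.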
nb_quantization_levels = 101


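# quantize() maps real values from [xmin, xmax] to integer levels in
# {0, ..., nb_quantization_levels - 1}.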
def quantize(x, xmin, xmax):
    return (
        ((x - xmin) / (xmax - xmin) * nb_quantization_levels)
        .long()
        .clamp(min=0, max=nb_quantization_levels - 1)
    )


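# dequantize() maps a quantization level back to a real value (the lower edge of
# the corresponding bin), so that dequantize(quantize(x)) snaps x onto the grid.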
def dequantize(q, xmin, xmax):
    return q / nb_quantization_levels * (xmax - xmin) + xmin


######################################################################


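# Generates batch_nb_mlps random binary classification problems on [-1, 1]^2
# (the positive class is the union of two random axis-aligned rectangles),
# trains one small MLP per problem with progressively quantized parameters, and
# returns the quantized train sets, test sets, and parameters.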
def generate_sets_and_params(
    batch_nb_mlps,
    nb_samples,
    batch_size,
    nb_epochs,
    device=torch.device("cpu"),
    print_log=False,
):
    data_input = torch.zeros(batch_nb_mlps, 2 * nb_samples, 2, device=device)
    data_targets = torch.zeros(
        batch_nb_mlps, 2 * nb_samples, dtype=torch.int64, device=device
    )

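    # Resample the problems whose two classes are too unbalanced (the fraction
    # of positive samples must lie in [0.4, 0.6]).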
    while (data_targets.float().mean(-1) - 0.5).abs().max() > 0.1:
        i = (data_targets.float().mean(-1) - 0.5).abs() > 0.1
        nb = i.sum()

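        # Each positive region is the union of nb_rec random rectangles,
        # encoded as (x_min, x_max, y_min, y_max).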
        nb_rec = 2
        support = torch.rand(nb, nb_rec, 2, 3, device=device) * 2 - 1
        support = support.sort(-1).values
        support = support[:, :, :, torch.tensor([0, 2])].view(nb, nb_rec, 4)

        x = torch.rand(nb, 2 * nb_samples, 2, device=device) * 2 - 1
        y = (
            (
                (x[:, None, :, 0] >= support[:, :, None, 0]).long()
                * (x[:, None, :, 0] <= support[:, :, None, 1]).long()
                * (x[:, None, :, 1] >= support[:, :, None, 2]).long()
                * (x[:, None, :, 1] <= support[:, :, None, 3]).long()
            )
            .max(dim=1)
            .values
        )

        data_input[i], data_targets[i] = x, y

    train_input, train_targets = (
        data_input[:, :nb_samples],
        data_targets[:, :nb_samples],
    )
    test_input, test_targets = data_input[:, nb_samples:], data_targets[:, nb_samples:]

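    # Quantize the inputs and replace them with their dequantized values, so the
    # MLPs are trained on exactly the values encoded in the token sequences.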
    q_train_input = quantize(train_input, -1, 1)
    train_input = dequantize(q_train_input, -1, 1)

    q_test_input = quantize(test_input, -1, 1)
    test_input = dequantize(q_test_input, -1, 1)

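    # One-hidden-layer MLP (2 -> hidden_dim -> 2, ReLU) per problem, with
    # weights scaled by 1/sqrt(fan_in); all batch_nb_mlps models are trained
    # jointly.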
    hidden_dim = 32
    w1 = torch.randn(batch_nb_mlps, hidden_dim, 2, device=device) / math.sqrt(2)
    b1 = torch.zeros(batch_nb_mlps, hidden_dim, device=device)
    w2 = torch.randn(batch_nb_mlps, 2, hidden_dim, device=device) / math.sqrt(
        hidden_dim
    )
    b2 = torch.zeros(batch_nb_mlps, 2, device=device)

    w1.requires_grad_()
    b1.requires_grad_()
    w2.requires_grad_()
    b2.requires_grad_()
    optimizer = torch.optim.Adam([w1, b1, w2, b2], lr=1e-2)

    criterion = nn.CrossEntropyLoss()
    criterion.to(device)

    for k in range(nb_epochs):
        acc_train_loss = 0.0
        nb_train_errors = 0

        for input, targets in zip(
            train_input.split(batch_size, dim=1), train_targets.split(batch_size, dim=1)
        ):
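            # Batched forward pass: "mij,mnj->mni" applies MLP m's weight matrix
            # to its own mini-batch, for all MLPs at once.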
            h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
            h = F.relu(h)
            output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
            loss = F.cross_entropy(
                output.reshape(-1, output.size(-1)), targets.reshape(-1)
            )
            acc_train_loss += loss.item() * input.size(1)

            wta = output.argmax(-1)
            nb_train_errors += (wta != targets).long().sum(-1)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

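        # Progressive quantization: at epoch k, each parameter entry is
        # independently snapped to the quantization grid with probability
        # k / (nb_epochs - 1), so the MLPs are fully quantized at the last epoch.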
        with torch.no_grad():
            for p in [w1, b1, w2, b2]:
                m = (
                    torch.rand(p.size(), device=p.device) <= k / (nb_epochs - 1)
                ).long()
                pq = quantize(p, -2, 2)
                p[...] = (1 - m) * p + m * dequantize(pq, -2, 2)

        train_error = nb_train_errors / train_input.size(1)
        acc_train_loss = acc_train_loss / train_input.size(1)

        # print(f"{k=} {acc_train_loss=} {train_error=}")

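    # Pack everything as token sequences: the parameters are flattened,
    # quantized over [-2, 2], and concatenated; each train/test sample becomes
    # three tokens (quantized x_0, quantized x_1, label y).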
    q_params = torch.cat(
        [quantize(p.view(batch_nb_mlps, -1), -2, 2) for p in [w1, b1, w2, b2]], dim=1
    )
    q_train_set = torch.cat([q_train_input, train_targets[:, :, None]], -1).reshape(
        batch_nb_mlps, -1
    )
    q_test_set = torch.cat([q_test_input, test_targets[:, :, None]], -1).reshape(
        batch_nb_mlps, -1
    )

    return q_train_set, q_test_set, q_params


######################################################################


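# Rebuilds the MLPs from their quantized parameter sequences and returns one
# error rate per MLP, computed on the quantized set q_set. MLPs are processed
# in chunks of nb_mlps_per_batch.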
def evaluate_q_params(
    q_params, q_set, batch_size=25, device=torch.device("cpu"), nb_mlps_per_batch=1024
):
    errors = []
    nb_mlps = q_params.size(0)

    for n in range(0, nb_mlps, nb_mlps_per_batch):
        batch_nb_mlps = min(nb_mlps_per_batch, nb_mlps - n)
        batch_q_params = q_params[n : n + batch_nb_mlps]
        batch_q_set = q_set[n : n + batch_nb_mlps]
        hidden_dim = 32
        w1 = torch.empty(batch_nb_mlps, hidden_dim, 2, device=device)
        b1 = torch.empty(batch_nb_mlps, hidden_dim, device=device)
        w2 = torch.empty(batch_nb_mlps, 2, hidden_dim, device=device)
        b2 = torch.empty(batch_nb_mlps, 2, device=device)

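        # Dequantize the parameter tokens and copy them into w1, b1, w2, b2, in
        # the same order as they were packed in generate_sets_and_params().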
        with torch.no_grad():
            k = 0
            for p in [w1, b1, w2, b2]:
                print(f"{p.size()=}")
                x = dequantize(
                    batch_q_params[:, k : k + p.numel() // batch_nb_mlps], -2, 2
                ).view(p.size())
                p.copy_(x)
                k += p.numel() // batch_nb_mlps

        batch_q_set = batch_q_set.view(batch_nb_mlps, -1, 3)
        data_input = dequantize(batch_q_set[:, :, :2], -1, 1).to(device)
        data_targets = batch_q_set[:, :, 2].to(device)

        print(f"{data_input.size()=} {data_targets.size()=}")

        criterion = nn.CrossEntropyLoss()
        criterion.to(device)

        acc_loss = 0.0
        nb_errors = 0

        for input, targets in zip(
            data_input.split(batch_size, dim=1), data_targets.split(batch_size, dim=1)
        ):
            h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
            h = F.relu(h)
            output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
            loss = F.cross_entropy(
                output.reshape(-1, output.size(-1)), targets.reshape(-1)
            )
            acc_loss += loss.item() * input.size(1)
            wta = output.argmax(-1)
            nb_errors += (wta != targets).long().sum(-1)

        errors.append(nb_errors / data_input.size(1))
        acc_loss = acc_loss / data_input.size(1)

    return torch.cat(errors)


######################################################################


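# Generates nb_mlps problems in chunks of nb_mlps_per_batch and returns (i) the
# full token sequences (train set, separator, parameters) and (ii) the
# corresponding quantized test sets.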
def generate_sequence_and_test_set(
    nb_mlps,
    nb_samples,
    batch_size,
    nb_epochs,
    device,
    nb_mlps_per_batch=1024,
):
    seqs, q_test_sets = [], []

    for n in range(0, nb_mlps, nb_mlps_per_batch):
        q_train_set, q_test_set, q_params = generate_sets_and_params(
            batch_nb_mlps=min(nb_mlps_per_batch, nb_mlps - n),
            nb_samples=nb_samples,
            batch_size=batch_size,
            nb_epochs=nb_epochs,
            device=device,
        )

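        # One sequence per MLP: the quantized train set (three tokens per
        # sample), a separator token of value nb_quantization_levels, then the
        # quantized parameters.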
        seqs.append(
            torch.cat(
                [
                    q_train_set,
                    q_train_set.new_full(
                        (
                            q_train_set.size(0),
                            1,
                        ),
                        nb_quantization_levels,
                    ),
                    q_params,
                ],
                dim=-1,
            )
        )

        q_test_sets.append(q_test_set)

    seq = torch.cat(seqs)
    q_test_set = torch.cat(q_test_sets)

    return seq, q_test_set


######################################################################

if __name__ == "__main__":
    import time

    batch_nb_mlps, nb_samples = 128, 500

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    start_time = time.perf_counter()

    data = []

    seq, q_test_set = generate_sequence_and_test_set(
        nb_mlps=batch_nb_mlps,
        nb_samples=nb_samples,
        device=device,
        batch_size=25,
        nb_epochs=250,
        nb_mlps_per_batch=17,
    )

    end_time = time.perf_counter()
    print(f"{seq.size(0) / (end_time - start_time):.02f} samples per second")

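    # Sanity check: split the sequences back into train-set tokens and parameter
    # tokens, and evaluate the encoded MLPs on both the train and test sets.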
    q_train_set = seq[:, : nb_samples * 3]
    q_params = seq[:, nb_samples * 3 + 1 :]
    print(f"SANITY #2 {q_train_set.size()=} {q_params.size()=} {seq.size()=}")
    error_train = evaluate_q_params(q_params, q_train_set, nb_mlps_per_batch=17)
    print(f"train {error_train*100}%")
    error_test = evaluate_q_params(q_params, q_test_set, nb_mlps_per_batch=17)
    print(f"test {error_test*100}%")