Update.
[picoclvr.git] / qmlp.py
1 #!/usr/bin/env python
2
3 # @XREMOTE_HOST: elk.fleuret.org
4 # @XREMOTE_EXEC: python
5 # @XREMOTE_PRE: source ${HOME}/misc/venv/pytorch/bin/activate
6 # @XREMOTE_PRE: killall -u ${USER} -q -9 python || true
7 # @XREMOTE_PRE: ln -sf ${HOME}/data/pytorch ./data
8 # @XREMOTE_SEND: *.py *.sh
9
10 # Any copyright is dedicated to the Public Domain.
11 # https://creativecommons.org/publicdomain/zero/1.0/
12
13 # Written by Francois Fleuret <francois@fleuret.org>
14
15 import math, sys
16
17 import torch, torchvision
18
19 from torch import nn
20 from torch.nn import functional as F
21
22 ######################################################################
23
# Number of discrete integer codes used for all quantized values.
nb_quantization_levels = 101


def quantize(x, xmin, xmax):
    """Map real values in [xmin, xmax] to integer codes in [0, nb_quantization_levels)."""
    scaled = (x - xmin) / (xmax - xmin) * nb_quantization_levels
    # Truncate to integer codes and clip out-of-range inputs to the valid bins.
    return scaled.long().clamp(min=0, max=nb_quantization_levels - 1)


def dequantize(q, xmin, xmax):
    """Map integer codes back to (approximate) real values in [xmin, xmax]."""
    span = xmax - xmin
    return q / nb_quantization_levels * span + xmin
37
38
39 ######################################################################
40
41
def create_model():
    """Build the reference MLP: 2-d input, two ReLU hidden layers of width 32, 2 logits out."""
    hidden_dim = 32

    layers = [
        nn.Linear(2, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, 2),
    ]

    return nn.Sequential(*layers)
54
55
56 ######################################################################
57
58
def generate_sets_and_params(
    nb_mlps,
    nb_samples,
    batch_size,
    nb_epochs,
    device=torch.device("cpu"),
    print_log=False,
):
    """Create nb_mlps random 2-d binary classification problems, train one
    small MLP per problem (all in parallel), and return everything quantized.

    Each problem's positive class is the union of two random axis-aligned
    rectangles in [-1, 1]^2; problems are re-sampled until every one is
    roughly class-balanced (|mean label - 0.5| <= 0.1).

    Args:
        nb_mlps: number of independent problems / MLPs trained in parallel.
        nb_samples: number of train samples per problem (the test set has
            the same size).
        batch_size: mini-batch size along the sample dimension.
        nb_epochs: number of passes over the training set.
        device: torch device for generation and training.
        print_log: when True, print re-sampling and per-epoch diagnostics.

    Returns:
        q_train_set: (nb_mlps, 3 * nb_samples) int tensor of interleaved
            quantized (x, y, label) triplets.
        q_test_set: same layout for the held-out test set.
        q_params: (nb_mlps, P) int tensor of quantized parameters
            (w1, b1, w2, b2 flattened and concatenated).
    """
    data_input = torch.zeros(nb_mlps, 2 * nb_samples, 2, device=device)
    data_targets = torch.zeros(
        nb_mlps, 2 * nb_samples, dtype=torch.int64, device=device
    )

    # Targets start all-zero, so the loop always runs at least once; each
    # pass re-samples only the problems whose class balance is still off.
    while (data_targets.float().mean(-1) - 0.5).abs().max() > 0.1:
        i = (data_targets.float().mean(-1) - 0.5).abs() > 0.1
        nb = i.sum()
        if print_log:
            print(f"{nb=}")

        nb_rec = 2
        support = torch.rand(nb, nb_rec, 2, 3, device=device) * 2 - 1
        support = support.sort(-1).values
        # Keep min and max of each sorted triplet as the rectangle bounds
        # (xmin, xmax, ymin, ymax).
        support = support[:, :, :, torch.tensor([0, 2])].view(nb, nb_rec, 4)

        x = torch.rand(nb, 2 * nb_samples, 2, device=device) * 2 - 1
        # Label is 1 iff the point lies inside at least one rectangle.
        y = (
            (
                (x[:, None, :, 0] >= support[:, :, None, 0]).long()
                * (x[:, None, :, 0] <= support[:, :, None, 1]).long()
                * (x[:, None, :, 1] >= support[:, :, None, 2]).long()
                * (x[:, None, :, 1] <= support[:, :, None, 3]).long()
            )
            .max(dim=1)
            .values
        )

        data_input[i], data_targets[i] = x, y

    train_input, train_targets = (
        data_input[:, :nb_samples],
        data_targets[:, :nb_samples],
    )
    test_input, test_targets = data_input[:, nb_samples:], data_targets[:, nb_samples:]

    # Snap the inputs to the quantization grid so training sees exactly the
    # values a consumer of the quantized sets will reconstruct.
    q_train_input = quantize(train_input, -1, 1)
    train_input = dequantize(q_train_input, -1, 1)

    q_test_input = quantize(test_input, -1, 1)
    test_input = dequantize(q_test_input, -1, 1)

    hidden_dim = 32
    w1 = torch.randn(nb_mlps, hidden_dim, 2, device=device) / math.sqrt(2)
    b1 = torch.zeros(nb_mlps, hidden_dim, device=device)
    w2 = torch.randn(nb_mlps, 2, hidden_dim, device=device) / math.sqrt(hidden_dim)
    b2 = torch.zeros(nb_mlps, 2, device=device)

    w1.requires_grad_()
    b1.requires_grad_()
    w2.requires_grad_()
    b2.requires_grad_()
    optimizer = torch.optim.Adam([w1, b1, w2, b2], lr=1e-2)

    for k in range(nb_epochs):
        acc_train_loss = 0.0
        nb_train_errors = 0

        for input, targets in zip(
            train_input.split(batch_size, dim=1), train_targets.split(batch_size, dim=1)
        ):
            # Batched forward pass of all the MLPs at once
            # (m = mlp index, n = sample index).
            h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
            h = F.relu(h)
            output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
            loss = F.cross_entropy(
                output.reshape(-1, output.size(-1)), targets.reshape(-1)
            )
            # Weight the mean batch loss by the number of samples in the
            # batch so the per-epoch average below is correct.
            acc_train_loss += loss.item() * input.size(1)

            wta = output.argmax(-1)
            nb_train_errors += (wta != targets).long().sum(-1)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Progressive quantization: replace a growing random fraction of the
        # parameters with their quantized values, reaching 100% at the last
        # epoch, so the MLPs end up trained to work fully quantized.
        with torch.no_grad():
            for p in [w1, b1, w2, b2]:
                m = (
                    torch.rand(p.size(), device=p.device)
                    <= k / max(nb_epochs - 1, 1)
                ).long()
                pq = quantize(p, -2, 2)
                p[...] = (1 - m) * p + m * dequantize(pq, -2, 2)

        train_error = nb_train_errors / train_input.size(1)
        acc_train_loss = acc_train_loss / train_input.size(1)

        if print_log:
            print(f"{k=} {acc_train_loss=} {train_error=}")

    q_params = torch.cat(
        [quantize(p.view(nb_mlps, -1), -2, 2) for p in [w1, b1, w2, b2]], dim=1
    )
    q_train_set = torch.cat([q_train_input, train_targets[:, :, None]], -1).reshape(
        nb_mlps, -1
    )
    q_test_set = torch.cat([q_test_input, test_targets[:, :, None]], -1).reshape(
        nb_mlps, -1
    )

    return q_train_set, q_test_set, q_params
171
172
173 ######################################################################
174
175
def evaluate_q_params(q_params, q_set, batch_size=25, device=torch.device("cpu")):
    """Rebuild the MLPs from their quantized parameters and compute each
    one's classification error rate on a quantized data set.

    Args:
        q_params: (nb_mlps, P) int tensor as produced by
            generate_sets_and_params (quantized w1, b1, w2, b2, flattened
            and concatenated in that order).
        q_set: (nb_mlps, 3 * n) int tensor of interleaved quantized
            (x, y, label) triplets.
        batch_size: mini-batch size along the sample dimension.
        device: torch device on which the evaluation runs.

    Returns:
        (nb_mlps,) tensor with the error rate of each MLP on its set.
    """
    nb_mlps = q_params.size(0)
    hidden_dim = 32  # must match the architecture used at training time
    w1 = torch.empty(nb_mlps, hidden_dim, 2, device=device)
    b1 = torch.empty(nb_mlps, hidden_dim, device=device)
    w2 = torch.empty(nb_mlps, 2, hidden_dim, device=device)
    b2 = torch.empty(nb_mlps, 2, device=device)

    # Slice the flat quantized parameter vector back into the four tensors,
    # in the same order they were concatenated.
    with torch.no_grad():
        k = 0
        for p in [w1, b1, w2, b2]:
            n = p.numel() // nb_mlps
            p.copy_(dequantize(q_params[:, k : k + n], -2, 2).view(p.size()))
            k += n

    q_set = q_set.view(nb_mlps, -1, 3)
    data_input = dequantize(q_set[:, :, :2], -1, 1).to(device)
    data_targets = q_set[:, :, 2].to(device)

    nb_errors = 0

    for input, targets in zip(
        data_input.split(batch_size, dim=1), data_targets.split(batch_size, dim=1)
    ):
        # Batched forward pass (m = mlp index, n = sample index).
        h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
        h = F.relu(h)
        output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
        wta = output.argmax(-1)
        nb_errors += (wta != targets).long().sum(-1)

    return nb_errors / data_input.size(1)
221
222
223 ######################################################################
224
225
def generate_sequence_and_test_set(
    nb_mlps,
    nb_samples,
    batch_size,
    nb_epochs,
    device,
):
    """Generate sequences for a sequence model, one per MLP, with layout
    [quantized train set | separator token | quantized MLP parameters].

    Returns:
        input: (nb_mlps, 3 * nb_samples + 1 + P) int tensor of sequences.
        ar_mask: same shape, 1 exactly on the parameter positions — the
            part a model is expected to generate auto-regressively.
        q_test_set: quantized held-out set for evaluating the parameters.
    """
    q_train_set, q_test_set, q_params = generate_sets_and_params(
        nb_mlps,
        nb_samples,
        batch_size,
        nb_epochs,
        device=device,
    )

    # The separator token is nb_quantization_levels, one past the largest
    # valid data/parameter code.
    input = torch.cat(
        [
            q_train_set,
            q_train_set.new_full(
                (
                    q_train_set.size(0),
                    1,
                ),
                nb_quantization_levels,
            ),
            q_params,
        ],
        dim=-1,
    )

    print(f"SANITY #1 {q_train_set.size()=} {q_params.size()=} {input.size()=}")

    # The mask must run along the sequence dimension (dim 1), not the batch
    # of MLPs (dim 0) as the original size(0) indexing did: positions
    # strictly after the separator (index q_train_set.size(1)) are the
    # q_params part of the sequence.
    ar_mask = (
        (torch.arange(input.size(1), device=input.device) > q_train_set.size(1))
        .long()
        .view(1, -1)
        .repeat(nb_mlps, 1)
    )

    return input, ar_mask, q_test_set
266
267
268 ######################################################################
269
if __name__ == "__main__":
    import time

    nb_mlps, nb_samples = 128, 200

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    start_time = time.perf_counter()

    # Generate two independent batches of sequences and time the throughput.
    data = [
        generate_sequence_and_test_set(
            nb_mlps=nb_mlps,
            nb_samples=nb_samples,
            device=device,
            batch_size=25,
            nb_epochs=250,
        )
        for _ in range(2)
    ]

    end_time = time.perf_counter()

    nb = sum(i.size(0) for i, _, _ in data)
    print(f"{nb / (end_time - start_time):.02f} samples per second")

    # Sanity check: recover the train set and the parameters from the
    # sequence layout [train set | separator | params] and evaluate them.
    for input, ar_mask, q_test_set in data:
        q_train_set = input[:, : nb_samples * 3]
        q_params = input[:, nb_samples * 3 + 1 :]
        print(f"SANITY #2 {q_train_set.size()=} {q_params.size()=} {input.size()=}")
        error_train = evaluate_q_params(q_params, q_train_set)
        print(f"train {error_train*100}%")
        error_test = evaluate_q_params(q_params, q_test_set)
        print(f"test {error_test*100}%")