Update.
[picoclvr.git] / problems.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10 ######################################################################
11
12
13 class Problem:
14     def generate_sequences(self, nb):
15         pass
16
17     def seq2str(self, seq):
18         return "[NOT IMPLEMENTED]"
19
20     def compute_nb_correct(self, input, ar_mask, result):
21         nb_total = ar_mask.sum().item()
22         nb_correct = ((result == input).long() * ar_mask).sum().item()
23         return nb_total, nb_correct
24
25
26 ####################
27 class ProblemDegradation(Problem):
28     def __init__(self, nb_state_tokens=5, nb_time_steps=12, value_max=25, hard=False):
29         assert value_max // nb_state_tokens >= 2
30         self.nb_state_tokens = nb_state_tokens
31         self.nb_time_steps = nb_time_steps
32         self.value_max = value_max
33         self.hard = hard
34
35     def generate_sequences(self, nb):
36         x = (
37             torch.rand(nb, self.nb_state_tokens).sort(dim=-1).indices == 0
38         ).long() * self.value_max
39         seq = [x]
40
41         for t in range(self.nb_time_steps - 1):
42             v = (torch.rand(x.size()).sort(dim=-1).indices + 1) * (x >= 2).long()
43             u = (v.max(dim=-1, keepdim=True).values == v).long()
44             n = (
45                 (u * x)
46                 .minimum(2 + torch.randint(self.value_max // 4 - 2, x.size()))
47                 .sum(dim=-1, keepdim=True)
48             )
49             m = 1 + ((n - 1) * torch.rand(n.size())).long()
50             x = (
51                 x
52                 + m * u.roll(shifts=-1, dims=-1)
53                 - n * u
54                 + (n - m) * u.roll(shifts=1, dims=-1)
55             )
56             seq.append(x)
57
58         if self.hard:
59             seq.reverse()
60
61         seq = torch.cat(seq, dim=1)
62         return seq, seq.new_full(seq.size(), 1, dtype=torch.int64)
63
64     def compute_nb_correct(self, input, ar_mask, result):
65         nb_total = result.size(0)
66         nb_correct = 0
67         e = result.new_zeros(self.nb_state_tokens)
68
69         for seq in result:
70             states = list(seq.split(self.nb_state_tokens))
71             if self.hard:
72                 states.reverse()
73
74             d = states[0]
75             j = d.sort(descending=True).indices[0]
76             e.zero_()
77             e[j] = self.value_max
78             if (d - e).abs().sum() == 0:
79                 nb_errors = 0
80                 for k in range(len(states) - 1):
81                     d = states[k + 1] - states[k]
82                     j = d.sort(descending=False).indices[0]
83                     if (
84                         d[j] == 0
85                         or d[j] > self.value_max // 4
86                         or d[(j + 1) % e.size(0)] <= 0
87                         or d[(j + 1) % e.size(0)] >= -d[j]
88                     ):
89                         nb_errors += 1
90                     else:
91                         e.zero_()
92                         e[j] = d[j]
93                         e[(j + 1) % e.size(0)] = d[(j + 1) % e.size(0)]
94                         e[(j - 1) % e.size(0)] = -d[(j + 1) % e.size(0)] - d[j]
95                         if (d - e).abs().sum() > 0:
96                             nb_errors += 1
97                 if nb_errors == 0:
98                     nb_correct += 1
99
100         return nb_total, nb_correct
101
102     def seq2str(self, seq):
103         return " | ".join(
104             [" ".join([f"{x:02d}" for x in s]) for s in seq.split(self.nb_state_tokens)]
105         )
106
107
108 ####################
109
110
111 class ProblemTwoTargets(Problem):
112     def __init__(self, len_total=10, len_targets=3):
113         assert len_targets >= 3
114         assert len_total >= 3 * len_targets - 1
115         self.len_total = len_total
116         self.len_targets = len_targets
117
118     def generate_sequences(self, nb):
119         k = torch.arange(self.len_total)[None, :]
120         s = torch.randint(10, (nb, self.len_total))
121         l = torch.rand(nb, self.len_total)
122         l = l * (k <= self.len_total - self.len_targets).long()
123         k1 = l.argmax(dim=1, keepdim=True)
124         m = (k != k1).long() * (k != k1 + self.len_targets - 1).long()
125         s = s * m + 10 * (1 - m)
126         l = l * (
127             1
128             - (k + self.len_targets - 1 >= k1).long()
129             * (k < k1 + self.len_targets).long()
130         )
131         k2 = l.argmax(dim=1, keepdim=True)
132         m = (k != k2).long() * (k != k2 + self.len_targets - 1).long()
133         s = s * m + 11 * (1 - m)
134         a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
135         a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
136         sequences = torch.cat(
137             (
138                 s,
139                 torch.full((nb, 1), 12),
140                 a1,
141                 torch.full((nb, 1), 12),
142                 a2,
143                 torch.full((nb, 1), 12),
144             ),
145             1,
146         )
147         ar_mask = (sequences == 12).long()
148         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
149         return sequences, ar_mask
150
151     def seq2str(self, seq):
152         return "".join("0123456789-+|"[x.item()] for x in seq)
153
154
155 ####################
156
157
158 class ProblemByHeart(Problem):
159     def __init__(self, nb_sentences=100, len_prompt=8, len_result=8):
160         self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
161         self.seq[:, len_prompt] = 10
162
163     def generate_sequences(self, nb):
164         sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
165         ar_mask = (sequences == 10).long()
166         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
167         return sequences, ar_mask
168
169     def seq2str(self, seq):
170         return "".join("0123456789|"[x.item()] for x in seq)
171
172
173 ####################
174
175
176 class ProblemLearnOperator(Problem):
177     def __init__(self, nb_operators=100, len_source=6, len_result=9):
178         self.len_source = len_source
179         self.len_result = len_result
180         self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
181         self.operators = F.one_hot(
182             torch.rand(nb_operators, len_result, len_source).argmax(-1),
183             num_classes=len_source,
184         )
185
186     def generate_sequences(self, nb):
187         nb_operators = torch.randint(self.operators.size(0), (nb,))
188         operators = self.operators[nb_operators]
189         nb_operators = (
190             nb_operators[:, None]
191             // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
192         ) % 10
193         marker1 = torch.full((nb, 1), 10)
194         source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
195         marker2 = torch.full((nb, 1), 11)
196         result = operators.bmm(source[:, :, None]).squeeze(-1)
197         sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
198         ar_mask = (sequences == 11).long()
199         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
200         return sequences, ar_mask
201
202     def seq2str(self, seq):
203         return "".join("0123456789|>"[x.item()] for x in seq)
204
205
206 ####################
207
208
209 class ProblemGuessOperator(Problem):
210     def __init__(self, len_source=5, len_result=8):
211         self.len_source = len_source
212         self.len_result = len_result
213
214     def generate_sequences(self, nb):
215         operators = F.one_hot(
216             torch.rand(nb, self.len_result, self.len_source).argmax(-1),
217             num_classes=self.len_source,
218         )
219         source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
220         marker1 = torch.full((nb, 1), 10)
221         result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
222         marker2 = torch.full((nb, 1), 11)
223         source2 = torch.randint(10, (nb, self.len_source))
224         marker3 = torch.full((nb, 1), 12)
225         result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
226
227         sequences = torch.cat(
228             (source1, marker1, result1, marker2, source2, marker3, result2), 1
229         )
230         ar_mask = (sequences == 12).long()
231         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
232         return sequences, ar_mask
233
234     def seq2str(self, seq):
235         return "".join("0123456789>|~"[x.item()] for x in seq)
236
237
238 ####################
239
240
241 class ProblemAddition(Problem):
242     def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
243         self.nb_digits = nb_digits
244         self.zero_padded = zero_padded
245         self.inverted_result = inverted_result
246         self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")])
247         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
248
249     def tensorize(self, strings):
250         len_max = max([len(x) for x in strings])
251         return torch.cat(
252             [
253                 torch.tensor(
254                     [
255                         [self.char2id[c] for c in s + "$" * (len_max - len(s))]
256                         for s in strings
257                     ]
258                 )
259             ],
260             0,
261         )
262
263     def generate_sequences(self, nb):
264         sequences = []
265         for k in range(nb):
266             a, b = torch.randint(10**self.nb_digits, (2,))
267             c = a + b
268             a, b, c = str(a.item()), str(b.item()), str(c.item())
269             if self.zero_padded:
270                 a = "0" * (self.nb_digits - len(a)) + a
271                 b = "0" * (self.nb_digits - len(b)) + b
272                 c = "0" * (self.nb_digits + 1 - len(c)) + c
273             if self.inverted_result:
274                 c = c[::-1]
275             sequences.append(f"{a}+{b}={c}$")
276
277         sequences = self.tensorize(sequences)
278         ar_mask = (sequences == self.char2id["="]).long()
279         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
280         return sequences, ar_mask
281
282     def seq2str(self, seq):
283         return "".join(self.id2char[x.item()] for x in seq)
284
285
286 ####################
287
288
289 class ProblemMixing(Problem):
290     def __init__(self, height=4, width=4, nb_time_steps=9, hard=False):
291         self.height = height
292         self.width = width
293         self.nb_time_steps = nb_time_steps
294         self.hard = hard
295
296     def start_random(self, nb):
297         y = torch.arange(self.height * self.width).reshape(1, -1).expand(nb, -1)
298
299         # m = (torch.rand(y.size()).sort(dim=-1).indices < y.size(1) // 2).long()
300
301         i = torch.arange(self.height).reshape(1,-1,1).expand(nb,self.height,self.width)
302         j = torch.arange(self.width).reshape(1,1,-1).expand(nb,self.height,self.width)
303
304         ri = torch.randint(self.height, (nb,)).reshape(nb,1,1)
305         rj = torch.randint(self.width, (nb,)).reshape(nb,1,1)
306
307         m = 1 - torch.logical_or(i==ri,j==rj).long().flatten(1)
308
309         y = (y * m + self.height * self.width * (1 - m)).reshape(
310             nb, self.height, self.width
311         )
312
313         return y
314
315     def start_error(self, x):
316         i = torch.arange(self.height, device=x.device).reshape(1,-1,1).expand_as(x)
317         j = torch.arange(self.width, device=x.device).reshape(1,1,-1).expand_as(x)
318
319         ri = (x == self.height * self.width).long().sum(dim=-1).argmax(-1).view(-1,1,1)
320         rj = (x == self.height * self.width).long().sum(dim=-2).argmax(-1).view(-1,1,1)
321
322         m = 1 - torch.logical_or(i==ri,j==rj).long().flatten(1)
323
324         x = x.flatten(1)
325         u = torch.arange(self.height * self.width, device = x.device).reshape(1, -1)
326
327         d = (x - (m * u + (1 - m) * self.height * self.width)).abs().sum(-1)
328         return d
329
330     def moves(self, x):
331         y = (
332             x[:, None, :, :]
333             .expand(-1, self.height * 2 + self.width * 2, -1, -1)
334             .clone()
335         )
336         k = 0
337
338         for i in range(self.height):
339             y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=-1)
340             k += 1
341             y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=1)
342             k += 1
343
344         for j in range(self.width):
345             y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=-1)
346             k += 1
347             y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=1)
348             k += 1
349
350         return y
351
352     def generate_sequences(self, nb):
353         x = self.start_random(nb)
354
355         seq = [x.flatten(1)]
356
357         for t in range(self.nb_time_steps - 1):
358             y = self.moves(x)
359             x = y[torch.arange(nb), torch.randint(y.size(1), (nb,))]
360             seq.append(x.flatten(1))
361
362         if self.hard:
363             seq.reverse()
364
365         seq = torch.cat(seq, dim=1)
366         return seq, seq.new_full(seq.size(), 1, dtype=torch.int64)
367
368     def compute_nb_correct(self, input, ar_mask, result):
369         a = [
370             x.reshape(result.size(0), self.height, self.width)
371             for x in result.split(self.height * self.width, dim=1)
372         ]
373         if self.hard:
374             a.reverse()
375
376         x = a[0]
377
378         d = self.start_error(x)
379
380         for t in range(self.nb_time_steps - 1):
381             x0, x = a[t], a[t + 1]
382             y = self.moves(x0)
383             d = d + (x[:, None] - y).abs().sum((-1, -2)).min(dim=-1).values
384
385         nb_total, nb_correct = result.size(0), (d == 0).long().sum().item()
386
387         return nb_total, nb_correct
388
389     def seq2str(self, seq):
390         return " | ".join(
391             [
392                 " ".join(
393                     ["-".join([f"{x:02d}" if x < self.height * self.width else "**" for x in s]) for s in r.split(self.width)]
394                 )
395                 for r in seq.split(self.height * self.width)
396             ]
397         )
398
399
400 ####################
401
402 if __name__ == "__main__":
403     p = ProblemMixing()
404     s, m = p.generate_sequences(10000)
405     for x in s[:5]:
406         print(p.seq2str(x))
407     print(p.compute_nb_correct(None, None, s))