Update.
[mygptrnn.git] / problems.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10 ######################################################################
11
12
13 class Problem:
14     def generate_sequences(self, nb):
15         pass
16
17     def seq2str(self, seq):
18         return "[NOT IMPLEMENTED]"
19
20     def compute_nb_correct(self, input, ar_mask, result):
21         nb_total = ar_mask.sum().item()
22         nb_correct = ((result == input).long() * ar_mask).sum().item()
23         return nb_total, nb_correct
24
25
26 ####################
27
28
29 class ProblemDegradation(Problem):
30     def __init__(self, nb_state_tokens=5, nb_time_steps=12, value_max=25, hard=False):
31         assert value_max // nb_state_tokens >= 2
32         self.nb_state_tokens = nb_state_tokens
33         self.nb_time_steps = nb_time_steps
34         self.value_max = value_max
35         self.hard = hard
36
37     def generate_sequences(self, nb):
38         x = (
39             torch.rand(nb, self.nb_state_tokens).sort(dim=-1).indices == 0
40         ).long() * self.value_max
41         seq = [x]
42
43         for t in range(self.nb_time_steps - 1):
44             v = (torch.rand(x.size()).sort(dim=-1).indices + 1) * (x >= 2).long()
45             u = (v.max(dim=-1, keepdim=True).values == v).long()
46             n = (
47                 (u * x)
48                 .minimum(2 + torch.randint(self.value_max // 4 - 2, x.size()))
49                 .sum(dim=-1, keepdim=True)
50             )
51             m = 1 + ((n - 1) * torch.rand(n.size())).long()
52             x = (
53                 x
54                 + m * u.roll(shifts=-1, dims=-1)
55                 - n * u
56                 + (n - m) * u.roll(shifts=1, dims=-1)
57             )
58             seq.append(x)
59
60         if self.hard:
61             seq.reverse()
62
63         seq = torch.cat(seq, dim=1)
64         return seq, seq.new_full(seq.size(), 1, dtype=torch.int64)
65
66     def compute_nb_correct(self, input, ar_mask, result):
67         nb_total = result.size(0)
68         nb_correct = 0
69         e = result.new_zeros(self.nb_state_tokens)
70
71         for seq in result:
72             states = list(seq.split(self.nb_state_tokens))
73             if self.hard:
74                 states.reverse()
75
76             d = states[0]
77             j = d.sort(descending=True).indices[0]
78             e.zero_()
79             e[j] = self.value_max
80             if (d - e).abs().sum() == 0:
81                 nb_errors = 0
82                 for k in range(len(states) - 1):
83                     d = states[k + 1] - states[k]
84                     j = d.sort(descending=False).indices[0]
85                     if (
86                         d[j] == 0
87                         or d[j] > self.value_max // 4
88                         or d[(j + 1) % e.size(0)] <= 0
89                         or d[(j + 1) % e.size(0)] >= -d[j]
90                     ):
91                         nb_errors += 1
92                     else:
93                         e.zero_()
94                         e[j] = d[j]
95                         e[(j + 1) % e.size(0)] = d[(j + 1) % e.size(0)]
96                         e[(j - 1) % e.size(0)] = -d[(j + 1) % e.size(0)] - d[j]
97                         if (d - e).abs().sum() > 0:
98                             nb_errors += 1
99                 if nb_errors == 0:
100                     nb_correct += 1
101
102         return nb_total, nb_correct
103
104     def seq2str(self, seq):
105         return " | ".join(
106             [" ".join([f"{x:02d}" for x in s]) for s in seq.split(self.nb_state_tokens)]
107         )
108
109
110 ####################
111
112
113 class ProblemMemory(Problem):
114     def __init__(self, len_total=32):
115         self.len_total = len_total
116         self.max_len_pattern = 5
117         self.nb_noise_tokens = 10
118         self.start_pattern_token = 0
119         self.end_pattern_token = 1
120         self.start_result_token = 2
121         self.end_result_token = 3
122         self.token_string = "[]<>" + "".join(
123             [chr(ord("a") + k) for k in range(self.nb_noise_tokens)]
124         )
125
126     def generate_sequences(self, nb):
127         sequences = (
128             torch.randint(self.nb_noise_tokens, (nb, self.len_total))
129             + self.end_result_token
130             + 1
131         )
132         len_patterns = torch.randint(self.max_len_pattern, (nb,)) + 1
133         pattern_positions = torch.randint(
134             self.len_total - (5 + 2 * self.max_len_pattern), (nb,)
135         )
136         k = self.len_total - (3 + self.max_len_pattern)
137         for i in range(nb):
138             l = len_patterns[i]
139             j = pattern_positions[i]
140             sequences[i, j] = self.start_pattern_token
141             sequences[i, j + l + 2] = self.end_pattern_token
142             sequences[i, k] = self.start_result_token
143             sequences[i, k + l + 2] = self.end_result_token
144             sequences[i, k + 1 : k + 2 + l] = sequences[i, j + 1 : j + 2 + l]
145
146         j = torch.arange(self.len_total)[None, :]
147         ar_mask = (j > k).long() * (j <= k + 1 + len_patterns[:, None]).long()
148
149         return sequences, ar_mask
150
151     def seq2str(self, seq):
152         def decode(x):
153             if x < len(self.token_string):
154                 return self.token_string[x]
155             else:
156                 return "?"
157
158         return "".join(decode(x.item()) for x in seq)
159
160
161 class ProblemTwoTargets(Problem):
162     def __init__(self, len_total=10, len_targets=3):
163         assert len_targets >= 3
164         assert len_total >= 3 * len_targets - 1
165         self.len_total = len_total
166         self.len_targets = len_targets
167
168     def generate_sequences(self, nb):
169         k = torch.arange(self.len_total)[None, :]
170         s = torch.randint(10, (nb, self.len_total))
171         l = torch.rand(nb, self.len_total)
172         l = l * (k <= self.len_total - self.len_targets).long()
173         k1 = l.argmax(dim=1, keepdim=True)
174         m = (k != k1).long() * (k != k1 + self.len_targets - 1).long()
175         s = s * m + 10 * (1 - m)
176         l = l * (
177             1
178             - (k + self.len_targets - 1 >= k1).long()
179             * (k < k1 + self.len_targets).long()
180         )
181         k2 = l.argmax(dim=1, keepdim=True)
182         m = (k != k2).long() * (k != k2 + self.len_targets - 1).long()
183         s = s * m + 11 * (1 - m)
184         a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
185         a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
186         sequences = torch.cat(
187             (
188                 s,
189                 torch.full((nb, 1), 12),
190                 a1,
191                 torch.full((nb, 1), 12),
192                 a2,
193                 torch.full((nb, 1), 12),
194             ),
195             1,
196         )
197         ar_mask = (sequences == 12).long()
198         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
199         return sequences, ar_mask
200
201     def seq2str(self, seq):
202         return "".join("0123456789-+|"[x.item()] for x in seq)
203
204
205 ####################
206
207
208 class ProblemByHeart(Problem):
209     def __init__(self, nb_sentences=100, len_prompt=8, len_result=8):
210         self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
211         self.seq[:, len_prompt] = 10
212
213     def generate_sequences(self, nb):
214         sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
215         ar_mask = (sequences == 10).long()
216         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
217         return sequences, ar_mask
218
219     def seq2str(self, seq):
220         return "".join("0123456789|"[x.item()] for x in seq)
221
222
223 ####################
224
225
226 class ProblemLearnOperator(Problem):
227     def __init__(self, nb_operators=100, len_source=6, len_result=9):
228         self.len_source = len_source
229         self.len_result = len_result
230         self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
231         self.operators = F.one_hot(
232             torch.rand(nb_operators, len_result, len_source).argmax(-1),
233             num_classes=len_source,
234         )
235
236     def generate_sequences(self, nb):
237         nb_operators = torch.randint(self.operators.size(0), (nb,))
238         operators = self.operators[nb_operators]
239         nb_operators = (
240             nb_operators[:, None]
241             // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
242         ) % 10
243         marker1 = torch.full((nb, 1), 10)
244         source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
245         marker2 = torch.full((nb, 1), 11)
246         result = operators.bmm(source[:, :, None]).squeeze(-1)
247         sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
248         ar_mask = (sequences == 11).long()
249         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
250         return sequences, ar_mask
251
252     def seq2str(self, seq):
253         return "".join("0123456789|>"[x.item()] for x in seq)
254
255
256 ####################
257
258
259 class ProblemGuessOperator(Problem):
260     def __init__(self, len_source=5, len_result=8):
261         self.len_source = len_source
262         self.len_result = len_result
263
264     def generate_sequences(self, nb):
265         operators = F.one_hot(
266             torch.rand(nb, self.len_result, self.len_source).argmax(-1),
267             num_classes=self.len_source,
268         )
269         source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
270         marker1 = torch.full((nb, 1), 10)
271         result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
272         marker2 = torch.full((nb, 1), 11)
273         source2 = torch.randint(10, (nb, self.len_source))
274         marker3 = torch.full((nb, 1), 12)
275         result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
276
277         sequences = torch.cat(
278             (source1, marker1, result1, marker2, source2, marker3, result2), 1
279         )
280         ar_mask = (sequences == 12).long()
281         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
282         return sequences, ar_mask
283
284     def seq2str(self, seq):
285         return "".join("0123456789>|~"[x.item()] for x in seq)
286
287
288 ####################
289
290
291 class ProblemAddition(Problem):
292     def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
293         self.nb_digits = nb_digits
294         self.zero_padded = zero_padded
295         self.inverted_result = inverted_result
296         self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")])
297         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
298
299     def tensorize(self, strings):
300         len_max = max([len(x) for x in strings])
301         return torch.cat(
302             [
303                 torch.tensor(
304                     [
305                         [self.char2id[c] for c in s + "$" * (len_max - len(s))]
306                         for s in strings
307                     ]
308                 )
309             ],
310             0,
311         )
312
313     def generate_sequences(self, nb):
314         sequences = []
315         for k in range(nb):
316             a, b = torch.randint(10**self.nb_digits, (2,))
317             c = a + b
318             a, b, c = str(a.item()), str(b.item()), str(c.item())
319             if self.zero_padded:
320                 a = "0" * (self.nb_digits - len(a)) + a
321                 b = "0" * (self.nb_digits - len(b)) + b
322                 c = "0" * (self.nb_digits + 1 - len(c)) + c
323             if self.inverted_result:
324                 c = c[::-1]
325             sequences.append(f"{a}+{b}={c}$")
326
327         sequences = self.tensorize(sequences)
328         ar_mask = (sequences == self.char2id["="]).long()
329         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
330         return sequences, ar_mask
331
332     def seq2str(self, seq):
333         return "".join(self.id2char[x.item()] for x in seq)
334
335
336 ####################
337
338
339 class ProblemMixing(Problem):
340     def __init__(
341         self, height=4, width=4, nb_time_steps=9, hard=False, random_start=True
342     ):
343         self.height = height
344         self.width = width
345         self.nb_time_steps = nb_time_steps
346         self.hard = hard
347         self.random_start = random_start
348
349     def start_random(self, nb):
350         y = torch.arange(self.height * self.width).reshape(1, -1).expand(nb, -1)
351
352         if self.random_start:
353             i = (
354                 torch.arange(self.height)
355                 .reshape(1, -1, 1)
356                 .expand(nb, self.height, self.width)
357             )
358             j = (
359                 torch.arange(self.width)
360                 .reshape(1, 1, -1)
361                 .expand(nb, self.height, self.width)
362             )
363
364             ri = torch.randint(self.height, (nb,)).reshape(nb, 1, 1)
365             rj = torch.randint(self.width, (nb,)).reshape(nb, 1, 1)
366
367             m = 1 - torch.logical_or(i == ri, j == rj).long().flatten(1)
368
369             y = y * m + self.height * self.width * (1 - m)
370
371         y = y.reshape(nb, self.height, self.width)
372
373         return y
374
375     def start_error(self, x):
376         if self.random_start:
377             i = (
378                 torch.arange(self.height, device=x.device)
379                 .reshape(1, -1, 1)
380                 .expand_as(x)
381             )
382             j = torch.arange(self.width, device=x.device).reshape(1, 1, -1).expand_as(x)
383
384             ri = (
385                 (x == self.height * self.width)
386                 .long()
387                 .sum(dim=-1)
388                 .argmax(-1)
389                 .view(-1, 1, 1)
390             )
391             rj = (
392                 (x == self.height * self.width)
393                 .long()
394                 .sum(dim=-2)
395                 .argmax(-1)
396                 .view(-1, 1, 1)
397             )
398
399             m = 1 - torch.logical_or(i == ri, j == rj).long().flatten(1)
400         else:
401             m = 1
402
403         x = x.flatten(1)
404         u = torch.arange(self.height * self.width, device=x.device).reshape(1, -1)
405
406         d = (x - (m * u + (1 - m) * self.height * self.width)).abs().sum(-1)
407
408         return d
409
410     def moves(self, x):
411         y = (
412             x[:, None, :, :]
413             .expand(-1, self.height * 2 + self.width * 2, -1, -1)
414             .clone()
415         )
416         k = 0
417
418         for i in range(self.height):
419             y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=-1)
420             k += 1
421             y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=1)
422             k += 1
423
424         for j in range(self.width):
425             y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=-1)
426             k += 1
427             y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=1)
428             k += 1
429
430         return y
431
432     def generate_sequences(self, nb):
433         x = self.start_random(nb)
434
435         seq = [x.flatten(1)]
436
437         for t in range(self.nb_time_steps - 1):
438             y = self.moves(x)
439             x = y[torch.arange(nb), torch.randint(y.size(1), (nb,))]
440             seq.append(x.flatten(1))
441
442         if self.hard:
443             seq.reverse()
444
445         seq = torch.cat(seq, dim=1)
446         return seq, seq.new_full(seq.size(), 1, dtype=torch.int64)
447
448     def compute_nb_correct(self, input, ar_mask, result):
449         a = [
450             x.reshape(result.size(0), self.height, self.width)
451             for x in result.split(self.height * self.width, dim=1)
452         ]
453         if self.hard:
454             a.reverse()
455
456         x = a[0]
457
458         d = self.start_error(x)
459
460         for t in range(self.nb_time_steps - 1):
461             x0, x = a[t], a[t + 1]
462             y = self.moves(x0)
463             d = d + (x[:, None] - y).abs().sum((-1, -2)).min(dim=-1).values
464
465         nb_total, nb_correct = result.size(0), (d == 0).long().sum().item()
466
467         return nb_total, nb_correct
468
469     def seq2str(self, seq):
470         return " | ".join(
471             [
472                 " ".join(
473                     [
474                         "-".join(
475                             [
476                                 f"{x:02d}" if x < self.height * self.width else "**"
477                                 for x in s
478                             ]
479                         )
480                         for s in r.split(self.width)
481                     ]
482                 )
483                 for r in seq.split(self.height * self.width)
484             ]
485         )
486
487
488 ####################
489
490 if __name__ == "__main__":
491     p = ProblemMixing(height=3, width=3, random_start=False)
492
493     s, m = p.generate_sequences(10000)
494     for x in s[:5]:
495         print(p.seq2str(x))
496     print(p.compute_nb_correct(None, None, s))