Update.
[picoclvr.git] / problems.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10 ######################################################################
11
12
13 class Problem:
14     def generate_sequences(self, nb):
15         pass
16
17     def seq2str(self, seq):
18         return "[NOT IMPLEMENTED]"
19
20     def compute_nb_correct(self, input, ar_mask, result):
21         nb_total = ar_mask.sum().item()
22         nb_correct = ((result == input).long() * ar_mask).sum().item()
23         return nb_total, nb_correct
24
25
26 ####################
27
28
29 class ProblemDegradation(Problem):
30     def __init__(self, nb_state_tokens=5, nb_time_steps=12, value_max=25, hard=False):
31         assert value_max // nb_state_tokens >= 2
32         self.nb_state_tokens = nb_state_tokens
33         self.nb_time_steps = nb_time_steps
34         self.value_max = value_max
35         self.hard = hard
36
37     def generate_sequences(self, nb):
38         x = (
39             torch.rand(nb, self.nb_state_tokens).sort(dim=-1).indices == 0
40         ).long() * self.value_max
41         seq = [x]
42
43         for t in range(self.nb_time_steps - 1):
44             v = (torch.rand(x.size()).sort(dim=-1).indices + 1) * (x >= 2).long()
45             u = (v.max(dim=-1, keepdim=True).values == v).long()
46             n = (
47                 (u * x)
48                 .minimum(2 + torch.randint(self.value_max // 4 - 2, x.size()))
49                 .sum(dim=-1, keepdim=True)
50             )
51             m = 1 + ((n - 1) * torch.rand(n.size())).long()
52             x = (
53                 x
54                 + m * u.roll(shifts=-1, dims=-1)
55                 - n * u
56                 + (n - m) * u.roll(shifts=1, dims=-1)
57             )
58             seq.append(x)
59
60         if self.hard:
61             seq.reverse()
62
63         seq = torch.cat(seq, dim=1)
64         return seq, seq.new_full(seq.size(), 1, dtype=torch.int64)
65
66     def compute_nb_correct(self, input, ar_mask, result):
67         nb_total = result.size(0)
68         nb_correct = 0
69         e = result.new_zeros(self.nb_state_tokens)
70
71         for seq in result:
72             states = list(seq.split(self.nb_state_tokens))
73             if self.hard:
74                 states.reverse()
75
76             d = states[0]
77             j = d.sort(descending=True).indices[0]
78             e.zero_()
79             e[j] = self.value_max
80             if (d - e).abs().sum() == 0:
81                 nb_errors = 0
82                 for k in range(len(states) - 1):
83                     d = states[k + 1] - states[k]
84                     j = d.sort(descending=False).indices[0]
85                     if (
86                         d[j] == 0
87                         or d[j] > self.value_max // 4
88                         or d[(j + 1) % e.size(0)] <= 0
89                         or d[(j + 1) % e.size(0)] >= -d[j]
90                     ):
91                         nb_errors += 1
92                     else:
93                         e.zero_()
94                         e[j] = d[j]
95                         e[(j + 1) % e.size(0)] = d[(j + 1) % e.size(0)]
96                         e[(j - 1) % e.size(0)] = -d[(j + 1) % e.size(0)] - d[j]
97                         if (d - e).abs().sum() > 0:
98                             nb_errors += 1
99                 if nb_errors == 0:
100                     nb_correct += 1
101
102         return nb_total, nb_correct
103
104     def seq2str(self, seq):
105         return " | ".join(
106             [" ".join([f"{x:02d}" for x in s]) for s in seq.split(self.nb_state_tokens)]
107         )
108
109
110 ####################
111
112
113 class ProblemTwoTargets(Problem):
114     def __init__(self, len_total=10, len_targets=3):
115         assert len_targets >= 3
116         assert len_total >= 3 * len_targets - 1
117         self.len_total = len_total
118         self.len_targets = len_targets
119
120     def generate_sequences(self, nb):
121         k = torch.arange(self.len_total)[None, :]
122         s = torch.randint(10, (nb, self.len_total))
123         l = torch.rand(nb, self.len_total)
124         l = l * (k <= self.len_total - self.len_targets).long()
125         k1 = l.argmax(dim=1, keepdim=True)
126         m = (k != k1).long() * (k != k1 + self.len_targets - 1).long()
127         s = s * m + 10 * (1 - m)
128         l = l * (
129             1
130             - (k + self.len_targets - 1 >= k1).long()
131             * (k < k1 + self.len_targets).long()
132         )
133         k2 = l.argmax(dim=1, keepdim=True)
134         m = (k != k2).long() * (k != k2 + self.len_targets - 1).long()
135         s = s * m + 11 * (1 - m)
136         a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
137         a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
138         sequences = torch.cat(
139             (
140                 s,
141                 torch.full((nb, 1), 12),
142                 a1,
143                 torch.full((nb, 1), 12),
144                 a2,
145                 torch.full((nb, 1), 12),
146             ),
147             1,
148         )
149         ar_mask = (sequences == 12).long()
150         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
151         return sequences, ar_mask
152
153     def seq2str(self, seq):
154         return "".join("0123456789-+|"[x.item()] for x in seq)
155
156
157 ####################
158
159
160 class ProblemByHeart(Problem):
161     def __init__(self, nb_sentences=100, len_prompt=8, len_result=8):
162         self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
163         self.seq[:, len_prompt] = 10
164
165     def generate_sequences(self, nb):
166         sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
167         ar_mask = (sequences == 10).long()
168         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
169         return sequences, ar_mask
170
171     def seq2str(self, seq):
172         return "".join("0123456789|"[x.item()] for x in seq)
173
174
175 ####################
176
177
178 class ProblemLearnOperator(Problem):
179     def __init__(self, nb_operators=100, len_source=6, len_result=9):
180         self.len_source = len_source
181         self.len_result = len_result
182         self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
183         self.operators = F.one_hot(
184             torch.rand(nb_operators, len_result, len_source).argmax(-1),
185             num_classes=len_source,
186         )
187
188     def generate_sequences(self, nb):
189         nb_operators = torch.randint(self.operators.size(0), (nb,))
190         operators = self.operators[nb_operators]
191         nb_operators = (
192             nb_operators[:, None]
193             // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
194         ) % 10
195         marker1 = torch.full((nb, 1), 10)
196         source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
197         marker2 = torch.full((nb, 1), 11)
198         result = operators.bmm(source[:, :, None]).squeeze(-1)
199         sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
200         ar_mask = (sequences == 11).long()
201         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
202         return sequences, ar_mask
203
204     def seq2str(self, seq):
205         return "".join("0123456789|>"[x.item()] for x in seq)
206
207
208 ####################
209
210
211 class ProblemGuessOperator(Problem):
212     def __init__(self, len_source=5, len_result=8):
213         self.len_source = len_source
214         self.len_result = len_result
215
216     def generate_sequences(self, nb):
217         operators = F.one_hot(
218             torch.rand(nb, self.len_result, self.len_source).argmax(-1),
219             num_classes=self.len_source,
220         )
221         source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
222         marker1 = torch.full((nb, 1), 10)
223         result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
224         marker2 = torch.full((nb, 1), 11)
225         source2 = torch.randint(10, (nb, self.len_source))
226         marker3 = torch.full((nb, 1), 12)
227         result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
228
229         sequences = torch.cat(
230             (source1, marker1, result1, marker2, source2, marker3, result2), 1
231         )
232         ar_mask = (sequences == 12).long()
233         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
234         return sequences, ar_mask
235
236     def seq2str(self, seq):
237         return "".join("0123456789>|~"[x.item()] for x in seq)
238
239
240 ####################
241
242
243 class ProblemAddition(Problem):
244     def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
245         self.nb_digits = nb_digits
246         self.zero_padded = zero_padded
247         self.inverted_result = inverted_result
248         self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")])
249         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
250
251     def tensorize(self, strings):
252         len_max = max([len(x) for x in strings])
253         return torch.cat(
254             [
255                 torch.tensor(
256                     [
257                         [self.char2id[c] for c in s + "$" * (len_max - len(s))]
258                         for s in strings
259                     ]
260                 )
261             ],
262             0,
263         )
264
265     def generate_sequences(self, nb):
266         sequences = []
267         for k in range(nb):
268             a, b = torch.randint(10**self.nb_digits, (2,))
269             c = a + b
270             a, b, c = str(a.item()), str(b.item()), str(c.item())
271             if self.zero_padded:
272                 a = "0" * (self.nb_digits - len(a)) + a
273                 b = "0" * (self.nb_digits - len(b)) + b
274                 c = "0" * (self.nb_digits + 1 - len(c)) + c
275             if self.inverted_result:
276                 c = c[::-1]
277             sequences.append(f"{a}+{b}={c}$")
278
279         sequences = self.tensorize(sequences)
280         ar_mask = (sequences == self.char2id["="]).long()
281         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
282         return sequences, ar_mask
283
284     def seq2str(self, seq):
285         return "".join(self.id2char[x.item()] for x in seq)
286
287
288 if __name__ == "__main__":
289     p = ProblemDegradation(hard=False)
290     s, m = p.generate_sequences(10000)
291     for x in s[:100]:
292         print(p.seq2str(x))
293     print(p.compute_nb_correct(None, None, s))