Update.
[picoclvr.git] / problems.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10 ######################################################################
11
12
13 class Problem:
14     def generate_sequences(self, nb):
15         pass
16
17     def seq2str(self, seq):
18         return "[NOT IMPLEMENTED]"
19
20     def compute_nb_correct(self, input, ar_mask, result):
21         nb_total = ar_mask.sum().item()
22         nb_correct = ((result == input).long() * ar_mask).sum().item()
23         return nb_total, nb_correct
24
25 ####################
26
27
28 class ProblemDegradation(Problem):
29     def __init__(self, nb_state_tokens=5, nb_time_steps=5, value_max=25, hard=False):
30         self.nb_state_tokens = nb_state_tokens
31         self.nb_time_steps = nb_time_steps
32         self.value_max = value_max
33         self.hard = hard
34
35     def generate_sequences(self,nb):
36
37         x = (torch.rand(nb,self.nb_state_tokens).sort(dim=-1).indices == 0).long() * self.value_max
38         seq = [x]
39
40         for t in range(self.nb_time_steps-1):
41             v = torch.rand(x.size()) * (x > 0).float()
42             u = (v.max(dim=-1,keepdim=True).values == v).long()
43             n = (u*x*torch.rand(x.size())).long().sum(dim=-1,keepdim=True) // 2
44             x = x + n * (u.roll(shifts=-1,dims=-1) - 2 * u + u.roll(shifts=1,dims=-1))
45             seq.append(x)
46
47         if self.hard: seq.reverse()
48
49         seq = torch.cat(seq,dim=1)
50         return seq,seq.new_full(seq.size(), 1, dtype=torch.int64)
51
52     def compute_nb_correct(self, input, ar_mask, result):
53         nb_total = result.size(0)
54         nb_correct = 0
55
56         for seq in result:
57             states = list(seq.split(self.nb_state_tokens))
58             if self.hard:
59                 states.reverse()
60
61             d = states[0]
62             j=d.sort(descending=True).indices[0]
63             e=d.new_zeros(d.size())
64             e[j]=self.value_max
65             if (d-e).abs().sum() == 0:
66                 nb_errors = 0
67                 for k in range(len(states)-1):
68                     d=states[k]-states[k+1]
69                     j=d.sort(descending=True).indices[0]
70                     e=d.new_zeros(d.size())
71                     e[j]=d[j]
72                     e[(j+1)%e.size(0)]=-d[j]//2
73                     e[(j-1)%e.size(0)]=-d[j]//2
74                     if (d-e).abs().sum() > 0:
75                         nb_errors += 1
76                 if nb_errors == 0:
77                     nb_correct += 1
78
79         return nb_total, nb_correct
80
81     def seq2str(self, seq):
82         return " | ".join( [ " ".join([f"{x:02d}" for x in s ]) for s in seq.split(self.nb_state_tokens) ] )
83
84 ####################
85
86
87 class ProblemTwoTargets(Problem):
88     def __init__(self, len_total=10, len_targets=3):
89         assert len_targets >= 3
90         assert len_total >= 3 * len_targets - 1
91         self.len_total = len_total
92         self.len_targets = len_targets
93
94     def generate_sequences(self, nb):
95         k = torch.arange(self.len_total)[None, :]
96         s = torch.randint(10, (nb, self.len_total))
97         l = torch.rand(nb, self.len_total)
98         l = l * (k <= self.len_total - self.len_targets).long()
99         k1 = l.argmax(dim=1, keepdim=True)
100         m = (k != k1).long() * (k != k1 + self.len_targets - 1).long()
101         s = s * m + 10 * (1 - m)
102         l = l * (
103             1
104             - (k + self.len_targets - 1 >= k1).long()
105             * (k < k1 + self.len_targets).long()
106         )
107         k2 = l.argmax(dim=1, keepdim=True)
108         m = (k != k2).long() * (k != k2 + self.len_targets - 1).long()
109         s = s * m + 11 * (1 - m)
110         a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
111         a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
112         sequences = torch.cat(
113             (
114                 s,
115                 torch.full((nb, 1), 12),
116                 a1,
117                 torch.full((nb, 1), 12),
118                 a2,
119                 torch.full((nb, 1), 12),
120             ),
121             1,
122         )
123         ar_mask = (sequences == 12).long()
124         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
125         return sequences, ar_mask
126
127     def seq2str(self, seq):
128         return "".join("0123456789-+|"[x.item()] for x in seq)
129
130
131 ####################
132
133
134 class ProblemByHeart(Problem):
135     def __init__(self, nb_sentences=100, len_prompt=8, len_result=8):
136         self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
137         self.seq[:, len_prompt] = 10
138
139     def generate_sequences(self, nb):
140         sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
141         ar_mask = (sequences == 10).long()
142         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
143         return sequences, ar_mask
144
145     def seq2str(self, seq):
146         return "".join("0123456789|"[x.item()] for x in seq)
147
148
149 ####################
150
151
152 class ProblemLearnOperator(Problem):
153     def __init__(self, nb_operators=100, len_source=6, len_result=9):
154         self.len_source = len_source
155         self.len_result = len_result
156         self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
157         self.operators = F.one_hot(
158             torch.rand(nb_operators, len_result, len_source).argmax(-1),
159             num_classes=len_source,
160         )
161
162     def generate_sequences(self, nb):
163         nb_operators = torch.randint(self.operators.size(0), (nb,))
164         operators = self.operators[nb_operators]
165         nb_operators = (
166             nb_operators[:, None]
167             // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
168         ) % 10
169         marker1 = torch.full((nb, 1), 10)
170         source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
171         marker2 = torch.full((nb, 1), 11)
172         result = operators.bmm(source[:, :, None]).squeeze(-1)
173         sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
174         ar_mask = (sequences == 11).long()
175         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
176         return sequences, ar_mask
177
178     def seq2str(self, seq):
179         return "".join("0123456789|>"[x.item()] for x in seq)
180
181
182 ####################
183
184
185 class ProblemGuessOperator(Problem):
186     def __init__(self, len_source=5, len_result=8):
187         self.len_source = len_source
188         self.len_result = len_result
189
190     def generate_sequences(self, nb):
191         operators = F.one_hot(
192             torch.rand(nb, self.len_result, self.len_source).argmax(-1),
193             num_classes=self.len_source,
194         )
195         source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
196         marker1 = torch.full((nb, 1), 10)
197         result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
198         marker2 = torch.full((nb, 1), 11)
199         source2 = torch.randint(10, (nb, self.len_source))
200         marker3 = torch.full((nb, 1), 12)
201         result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
202
203         sequences = torch.cat(
204             (source1, marker1, result1, marker2, source2, marker3, result2), 1
205         )
206         ar_mask = (sequences == 12).long()
207         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
208         return sequences, ar_mask
209
210     def seq2str(self, seq):
211         return "".join("0123456789>|~"[x.item()] for x in seq)
212
213
214 ####################
215
216
217 class ProblemAddition(Problem):
218     def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
219         self.nb_digits = nb_digits
220         self.zero_padded = zero_padded
221         self.inverted_result = inverted_result
222         self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")])
223         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
224
225     def tensorize(self, strings):
226         len_max = max([len(x) for x in strings])
227         return torch.cat(
228             [
229                 torch.tensor(
230                     [
231                         [self.char2id[c] for c in s + "$" * (len_max - len(s))]
232                         for s in strings
233                     ]
234                 )
235             ],
236             0,
237         )
238
239     def generate_sequences(self, nb):
240         sequences = []
241         for k in range(nb):
242             a, b = torch.randint(10**self.nb_digits, (2,))
243             c = a + b
244             a, b, c = str(a.item()), str(b.item()), str(c.item())
245             if self.zero_padded:
246                 a = "0" * (self.nb_digits - len(a)) + a
247                 b = "0" * (self.nb_digits - len(b)) + b
248                 c = "0" * (self.nb_digits + 1 - len(c)) + c
249             if self.inverted_result:
250                 c = c[::-1]
251             sequences.append(f"{a}+{b}={c}$")
252
253         sequences = self.tensorize(sequences)
254         ar_mask = (sequences == self.char2id["="]).long()
255         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
256         return sequences, ar_mask
257
258     def seq2str(self, seq):
259         return "".join(self.id2char[x.item()] for x in seq)
260
261
262 if __name__ == "__main__":
263     p = ProblemDegradation(hard=False)
264     s, m = p.generate_sequences(10000)
265     print(p.seq2str(s[0]))
266     print(p.compute_nb_correct(None, None, s))