68a46b3cdb89dc79f892b5d3339927c039193e6a
[picoclvr.git] / problems.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10 ######################################################################
11
12
13 class Problem:
14     def generate_sequences(self, nb):
15         pass
16
17     def seq2str(self, seq):
18         return "[NOT IMPLEMENTED]"
19
20     def compute_nb_correct(self, input, ar_mask, result):
21         nb_total = ar_mask.sum().item()
22         nb_correct = ((result == input).long() * ar_mask).sum().item()
23         return nb_total, nb_correct
24
25 ####################
26
27
28 class ProblemTwoCuts(Problem):
29     def __init__(self, len_total=50, nb_values=100, global_constraint=True):
30         self.len_total = len_total
31         self.nb_values = nb_values
32         self.global_constraint = global_constraint
33
34     def generate_sequences_internal(self, nb):
35         return u,v,a,b,c
36
37     def generate_sequences(self,nb):
38
39         u = torch.randint(self.len_total, (nb,))
40         v = torch.randint(self.len_total, (nb,))
41
42         a = torch.randint(self.nb_values, (nb,))
43         b = torch.randint(self.nb_values, (nb,))
44         c = torch.randint(self.nb_values, (nb,))
45
46         while True:
47             to_compute = torch.logical_or(u>=v-self.len_total//10,u<v-self.len_total//5)
48             to_compute =torch.logical_or(to_compute, u == 0)
49             to_compute =torch.logical_or(to_compute, v == self.len_total)
50             n = to_compute.long().sum()
51             if n == 0:
52                 break
53             else:
54                 u[to_compute] = torch.randint(self.len_total, (n,))
55                 v[to_compute] = torch.randint(self.len_total, (n,))
56
57         while True:
58             to_compute = a==b
59             to_compute = torch.logical_or(to_compute,b==c)
60             to_compute = torch.logical_or(to_compute,a==c)
61
62             if self.global_constraint:
63                 to_compute = torch.logical_or(to_compute,(a*u+b*(v-u)+c*(self.len_total-v)) // self.len_total != self.nb_values//2)
64
65             n = to_compute.long().sum()
66             if n == 0:
67                 break
68             else:
69                 a[to_compute] = torch.randint(self.nb_values, (n,))
70                 b[to_compute] = torch.randint(self.nb_values, (n,))
71                 c[to_compute] = torch.randint(self.nb_values, (n,))
72
73         assert (u>=v).long().sum() == 0
74         assert (a==b).long().sum() == 0
75         assert (a==c).long().sum() == 0
76         assert (c==b).long().sum() == 0
77
78         t = torch.arange(self.len_total)
79         seq = (t[None,:] < u[:,None]).long() * a[:,None] + \
80             (t[None,:] >= u[:,None]).long() * (t[None,:] < v[:,None]).long() * b[:,None] + \
81             (t[None,:] >= v[:,None]).long() * c[:,None]
82
83         return seq,seq.new_full(seq.size(), 1, dtype=torch.int64)
84
85     def compute_nb_correct(self, input, ar_mask, result):
86         nb_total = result.size(0)
87         nb_correct = 0
88         i = torch.arange(result.size(1), device=result.device)
89
90         for k in range(nb_total):
91             s = result[k]
92             a = s[0]
93             uu = (s != a).nonzero()
94             if uu.size(0) > 0:
95                 u = uu.min()
96                 b = s[u]
97                 vv = torch.logical_and(s != b, i >= u).nonzero()
98                 if vv.size(0) > 0:
99                     v = vv.min()
100                     c = s[v]
101                     ww = torch.logical_and(s != c, i >= v).nonzero()
102                     if ww.size(0) == 0:
103                         if not self.global_constraint or (a*u+b*(v-u)+c*(self.len_total-v)) // self.len_total == self.nb_values//2:
104                             nb_correct += 1
105
106         return nb_total, nb_correct
107
108     def seq2str(self, seq):
109         return " ".join( [ f"{x:02d}" for x in seq ] )
110
111 ####################
112
113
114 class ProblemTwoTargets(Problem):
115     def __init__(self, len_total=10, len_targets=3):
116         assert len_targets >= 3
117         assert len_total >= 3 * len_targets - 1
118         self.len_total = len_total
119         self.len_targets = len_targets
120
121     def generate_sequences(self, nb):
122         k = torch.arange(self.len_total)[None, :]
123         s = torch.randint(10, (nb, self.len_total))
124         l = torch.rand(nb, self.len_total)
125         l = l * (k <= self.len_total - self.len_targets).long()
126         k1 = l.argmax(dim=1, keepdim=True)
127         m = (k != k1).long() * (k != k1 + self.len_targets - 1).long()
128         s = s * m + 10 * (1 - m)
129         l = l * (
130             1
131             - (k + self.len_targets - 1 >= k1).long()
132             * (k < k1 + self.len_targets).long()
133         )
134         k2 = l.argmax(dim=1, keepdim=True)
135         m = (k != k2).long() * (k != k2 + self.len_targets - 1).long()
136         s = s * m + 11 * (1 - m)
137         a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
138         a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
139         sequences = torch.cat(
140             (
141                 s,
142                 torch.full((nb, 1), 12),
143                 a1,
144                 torch.full((nb, 1), 12),
145                 a2,
146                 torch.full((nb, 1), 12),
147             ),
148             1,
149         )
150         ar_mask = (sequences == 12).long()
151         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
152         return sequences, ar_mask
153
154     def seq2str(self, seq):
155         return "".join("0123456789-+|"[x.item()] for x in seq)
156
157
158 ####################
159
160
161 class ProblemByHeart(Problem):
162     def __init__(self, nb_sentences=100, len_prompt=8, len_result=8):
163         self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
164         self.seq[:, len_prompt] = 10
165
166     def generate_sequences(self, nb):
167         sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
168         ar_mask = (sequences == 10).long()
169         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
170         return sequences, ar_mask
171
172     def seq2str(self, seq):
173         return "".join("0123456789|"[x.item()] for x in seq)
174
175
176 ####################
177
178
179 class ProblemLearnOperator(Problem):
180     def __init__(self, nb_operators=100, len_source=6, len_result=9):
181         self.len_source = len_source
182         self.len_result = len_result
183         self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
184         self.operators = F.one_hot(
185             torch.rand(nb_operators, len_result, len_source).argmax(-1),
186             num_classes=len_source,
187         )
188
189     def generate_sequences(self, nb):
190         nb_operators = torch.randint(self.operators.size(0), (nb,))
191         operators = self.operators[nb_operators]
192         nb_operators = (
193             nb_operators[:, None]
194             // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
195         ) % 10
196         marker1 = torch.full((nb, 1), 10)
197         source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
198         marker2 = torch.full((nb, 1), 11)
199         result = operators.bmm(source[:, :, None]).squeeze(-1)
200         sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
201         ar_mask = (sequences == 11).long()
202         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
203         return sequences, ar_mask
204
205     def seq2str(self, seq):
206         return "".join("0123456789|>"[x.item()] for x in seq)
207
208
209 ####################
210
211
212 class ProblemGuessOperator(Problem):
213     def __init__(self, len_source=5, len_result=8):
214         self.len_source = len_source
215         self.len_result = len_result
216
217     def generate_sequences(self, nb):
218         operators = F.one_hot(
219             torch.rand(nb, self.len_result, self.len_source).argmax(-1),
220             num_classes=self.len_source,
221         )
222         source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
223         marker1 = torch.full((nb, 1), 10)
224         result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
225         marker2 = torch.full((nb, 1), 11)
226         source2 = torch.randint(10, (nb, self.len_source))
227         marker3 = torch.full((nb, 1), 12)
228         result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
229
230         sequences = torch.cat(
231             (source1, marker1, result1, marker2, source2, marker3, result2), 1
232         )
233         ar_mask = (sequences == 12).long()
234         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
235         return sequences, ar_mask
236
237     def seq2str(self, seq):
238         return "".join("0123456789>|~"[x.item()] for x in seq)
239
240
241 ####################
242
243
244 class ProblemAddition(Problem):
245     def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
246         self.nb_digits = nb_digits
247         self.zero_padded = zero_padded
248         self.inverted_result = inverted_result
249         self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")])
250         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
251
252     def tensorize(self, strings):
253         len_max = max([len(x) for x in strings])
254         return torch.cat(
255             [
256                 torch.tensor(
257                     [
258                         [self.char2id[c] for c in s + "$" * (len_max - len(s))]
259                         for s in strings
260                     ]
261                 )
262             ],
263             0,
264         )
265
266     def generate_sequences(self, nb):
267         sequences = []
268         for k in range(nb):
269             a, b = torch.randint(10**self.nb_digits, (2,))
270             c = a + b
271             a, b, c = str(a.item()), str(b.item()), str(c.item())
272             if self.zero_padded:
273                 a = "0" * (self.nb_digits - len(a)) + a
274                 b = "0" * (self.nb_digits - len(b)) + b
275                 c = "0" * (self.nb_digits + 1 - len(c)) + c
276             if self.inverted_result:
277                 c = c[::-1]
278             sequences.append(f"{a}+{b}={c}$")
279
280         sequences = self.tensorize(sequences)
281         ar_mask = (sequences == self.char2id["="]).long()
282         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
283         return sequences, ar_mask
284
285     def seq2str(self, seq):
286         return "".join(self.id2char[x.item()] for x in seq)
287
288
289 if __name__ == "__main__":
290     p = ProblemTwoCuts(12)
291     s, m = p.generate_sequences(10000)
292     print(p.compute_nb_correct(None, None, s))