819715e1b5b1bab6af7207f8656f6aaefb8408f0
[picoclvr.git] / problems.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10 ######################################################################
11
12
13 class Problem:
14     def generate_sequences(self, nb):
15         pass
16
17     def seq2str(self, seq):
18         return "[NOT IMPLEMENTED]"
19
20     def compute_nb_correct(self, input, ar_mask, result):
21         nb_total = ar_mask.sum().item()
22         nb_correct = ((result == input).long() * ar_mask).sum().item()
23         return nb_total, nb_correct
24
25 ####################
26
27
28 class ProblemDegradation(Problem):
29     def __init__(self, nb_state_tokens=5, nb_time_steps=5, value_max=25, hard=False):
30         self.nb_state_tokens = nb_state_tokens
31         self.nb_time_steps = nb_time_steps
32         self.value_max = value_max
33         self.hard = hard
34
35     def generate_sequences(self,nb):
36
37         x = (torch.rand(nb,self.nb_state_tokens).sort(dim=-1).indices == 0).long() * self.value_max
38         seq = [x]
39
40         for t in range(self.nb_time_steps-1):
41             v = torch.rand(x.size()) * (x > 0).float()
42             u = (v.max(dim=-1,keepdim=True).values == v).long()
43             n = (u*x*torch.rand(x.size())).long().sum(dim=-1,keepdim=True) // 2
44             x = x + n * (u.roll(shifts=-1,dims=-1) - 2 * u + u.roll(shifts=1,dims=-1))
45             seq.append(x)
46
47         if self.hard: seq.reverse()
48
49         seq = torch.cat(seq,dim=1)
50         return seq,seq.new_full(seq.size(), 1, dtype=torch.int64)
51
52     def compute_nb_correct(self, input, ar_mask, result):
53         nb_total = result.size(0)
54         nb_correct = 0
55         e=result.new_zeros(self.nb_state_tokens)
56
57         for seq in result:
58             states = list(seq.split(self.nb_state_tokens))
59             if self.hard:
60                 states.reverse()
61
62             d = states[0]
63             j=d.sort(descending=True).indices[0]
64             e.zero_()
65             e[j]=self.value_max
66             if (d-e).abs().sum() == 0:
67                 nb_errors = 0
68                 for k in range(len(states)-1):
69                     d=states[k]-states[k+1]
70                     j=d.sort(descending=True).indices[0]
71                     e.zero_()
72                     e[j]=d[j]
73                     e[(j+1)%e.size(0)]=-d[j]//2
74                     e[(j-1)%e.size(0)]=-d[j]//2
75                     if (d-e).abs().sum() > 0:
76                         nb_errors += 1
77                 if nb_errors == 0:
78                     nb_correct += 1
79
80         return nb_total, nb_correct
81
82     def seq2str(self, seq):
83         return " | ".join( [ " ".join([f"{x:02d}" for x in s ]) for s in seq.split(self.nb_state_tokens) ] )
84
85 ####################
86
87
88 class ProblemTwoTargets(Problem):
89     def __init__(self, len_total=10, len_targets=3):
90         assert len_targets >= 3
91         assert len_total >= 3 * len_targets - 1
92         self.len_total = len_total
93         self.len_targets = len_targets
94
95     def generate_sequences(self, nb):
96         k = torch.arange(self.len_total)[None, :]
97         s = torch.randint(10, (nb, self.len_total))
98         l = torch.rand(nb, self.len_total)
99         l = l * (k <= self.len_total - self.len_targets).long()
100         k1 = l.argmax(dim=1, keepdim=True)
101         m = (k != k1).long() * (k != k1 + self.len_targets - 1).long()
102         s = s * m + 10 * (1 - m)
103         l = l * (
104             1
105             - (k + self.len_targets - 1 >= k1).long()
106             * (k < k1 + self.len_targets).long()
107         )
108         k2 = l.argmax(dim=1, keepdim=True)
109         m = (k != k2).long() * (k != k2 + self.len_targets - 1).long()
110         s = s * m + 11 * (1 - m)
111         a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
112         a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
113         sequences = torch.cat(
114             (
115                 s,
116                 torch.full((nb, 1), 12),
117                 a1,
118                 torch.full((nb, 1), 12),
119                 a2,
120                 torch.full((nb, 1), 12),
121             ),
122             1,
123         )
124         ar_mask = (sequences == 12).long()
125         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
126         return sequences, ar_mask
127
128     def seq2str(self, seq):
129         return "".join("0123456789-+|"[x.item()] for x in seq)
130
131
132 ####################
133
134
135 class ProblemByHeart(Problem):
136     def __init__(self, nb_sentences=100, len_prompt=8, len_result=8):
137         self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
138         self.seq[:, len_prompt] = 10
139
140     def generate_sequences(self, nb):
141         sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
142         ar_mask = (sequences == 10).long()
143         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
144         return sequences, ar_mask
145
146     def seq2str(self, seq):
147         return "".join("0123456789|"[x.item()] for x in seq)
148
149
150 ####################
151
152
153 class ProblemLearnOperator(Problem):
154     def __init__(self, nb_operators=100, len_source=6, len_result=9):
155         self.len_source = len_source
156         self.len_result = len_result
157         self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
158         self.operators = F.one_hot(
159             torch.rand(nb_operators, len_result, len_source).argmax(-1),
160             num_classes=len_source,
161         )
162
163     def generate_sequences(self, nb):
164         nb_operators = torch.randint(self.operators.size(0), (nb,))
165         operators = self.operators[nb_operators]
166         nb_operators = (
167             nb_operators[:, None]
168             // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
169         ) % 10
170         marker1 = torch.full((nb, 1), 10)
171         source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
172         marker2 = torch.full((nb, 1), 11)
173         result = operators.bmm(source[:, :, None]).squeeze(-1)
174         sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
175         ar_mask = (sequences == 11).long()
176         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
177         return sequences, ar_mask
178
179     def seq2str(self, seq):
180         return "".join("0123456789|>"[x.item()] for x in seq)
181
182
183 ####################
184
185
186 class ProblemGuessOperator(Problem):
187     def __init__(self, len_source=5, len_result=8):
188         self.len_source = len_source
189         self.len_result = len_result
190
191     def generate_sequences(self, nb):
192         operators = F.one_hot(
193             torch.rand(nb, self.len_result, self.len_source).argmax(-1),
194             num_classes=self.len_source,
195         )
196         source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
197         marker1 = torch.full((nb, 1), 10)
198         result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
199         marker2 = torch.full((nb, 1), 11)
200         source2 = torch.randint(10, (nb, self.len_source))
201         marker3 = torch.full((nb, 1), 12)
202         result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
203
204         sequences = torch.cat(
205             (source1, marker1, result1, marker2, source2, marker3, result2), 1
206         )
207         ar_mask = (sequences == 12).long()
208         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
209         return sequences, ar_mask
210
211     def seq2str(self, seq):
212         return "".join("0123456789>|~"[x.item()] for x in seq)
213
214
215 ####################
216
217
218 class ProblemAddition(Problem):
219     def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
220         self.nb_digits = nb_digits
221         self.zero_padded = zero_padded
222         self.inverted_result = inverted_result
223         self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")])
224         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
225
226     def tensorize(self, strings):
227         len_max = max([len(x) for x in strings])
228         return torch.cat(
229             [
230                 torch.tensor(
231                     [
232                         [self.char2id[c] for c in s + "$" * (len_max - len(s))]
233                         for s in strings
234                     ]
235                 )
236             ],
237             0,
238         )
239
240     def generate_sequences(self, nb):
241         sequences = []
242         for k in range(nb):
243             a, b = torch.randint(10**self.nb_digits, (2,))
244             c = a + b
245             a, b, c = str(a.item()), str(b.item()), str(c.item())
246             if self.zero_padded:
247                 a = "0" * (self.nb_digits - len(a)) + a
248                 b = "0" * (self.nb_digits - len(b)) + b
249                 c = "0" * (self.nb_digits + 1 - len(c)) + c
250             if self.inverted_result:
251                 c = c[::-1]
252             sequences.append(f"{a}+{b}={c}$")
253
254         sequences = self.tensorize(sequences)
255         ar_mask = (sequences == self.char2id["="]).long()
256         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
257         return sequences, ar_mask
258
259     def seq2str(self, seq):
260         return "".join(self.id2char[x.item()] for x in seq)
261
262
263 if __name__ == "__main__":
264     p = ProblemDegradation(hard=False)
265     s, m = p.generate_sequences(10000)
266     print(p.seq2str(s[0]))
267     print(p.compute_nb_correct(None, None, s))