Update.
[picoclvr.git] / problems.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10 ######################################################################
11
12
13 class Problem:
14     def generate_sequences(self, nb):
15         pass
16
17     def seq2str(self, seq):
18         return "[NOT IMPLEMENTED]"
19
20
21 ####################
22
23
24 class ProblemTwoTargets(Problem):
25     def __init__(self, len_total=10, len_targets=3):
26         assert len_targets >= 3
27         assert len_total >= 3 * len_targets - 1
28         self.len_total = len_total
29         self.len_targets = len_targets
30
31     def generate_sequences(self, nb):
32         k = torch.arange(self.len_total)[None, :]
33         s = torch.randint(10, (nb, self.len_total))
34         l = torch.rand(nb, self.len_total)
35         l = l * (k <= self.len_total - self.len_targets).long()
36         k1 = l.argmax(dim=1, keepdim=True)
37         m = (k != k1).long() * (k != k1 + self.len_targets - 1).long()
38         s = s * m + 10 * (1 - m)
39         l = l * (
40             1
41             - (k + self.len_targets - 1 >= k1).long()
42             * (k < k1 + self.len_targets).long()
43         )
44         k2 = l.argmax(dim=1, keepdim=True)
45         m = (k != k2).long() * (k != k2 + self.len_targets - 1).long()
46         s = s * m + 11 * (1 - m)
47         a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
48         a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
49         sequences = torch.cat(
50             (s, torch.full((nb, 1), 12), a1, torch.full((nb, 1), 12), a2), 1
51         )
52         ar_mask = (sequences == 12).long()
53         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
54         return sequences, ar_mask
55
56     def seq2str(self, seq):
57         return "".join("0123456789+-|"[x.item()] for x in seq)
58
59
60 ####################
61
62
63 class ProblemLenId(Problem):
64     def __init__(self, len_max=10):
65         self.len_max = len_max
66
67     def generate_sequences(self, nb):
68         k = torch.arange(self.len_max * 3 + 3)[None, :]
69         l = torch.randint(self.len_max, (2, nb))[:, :, None] + 1
70         i = torch.randint(10, (2, nb))[:, :, None]
71         a = l[0]
72         b = l[0] + 1 + l[1]
73         c = l[0] + 1 + l[1] + 1 + l[0]
74         sequences = (
75             (k < a) * i[0]
76             + (k == a) * 10
77             + (k > a) * (k < b) * i[1]
78             + (k == b) * 11
79             + (k > b) * (k < c) * i[1]
80             + (k >= c) * 12
81         )
82         ar_mask = (sequences == 11).long()
83         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
84         return sequences, ar_mask
85
86     def seq2str(self, seq):
87         return "".join("0123456789|>_"[x.item()] for x in seq)
88
89
90 ####################
91
92
93 class ProblemLevel0(Problem):
94     def __init__(self, nb_sentences=100, len_prompt=5, len_result=5):
95         self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
96         self.seq[:, len_prompt] = 10
97
98     def generate_sequences(self, nb):
99         sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
100         ar_mask = (sequences == 10).long()
101         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
102         return sequences, ar_mask
103
104     def seq2str(self, seq):
105         return "".join("0123456789|"[x.item()] for x in seq)
106
107
108 ####################
109
110
111 class ProblemLevel1(Problem):
112     def __init__(self, nb_operators=100, len_source=5, len_result=8):
113         self.len_source = len_source
114         self.len_result = len_result
115         self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
116         self.operators = F.one_hot(
117             torch.rand(nb_operators, len_result, len_source).argmax(-1),
118             num_classes=len_source,
119         )
120
121     def generate_sequences(self, nb):
122         nb_operators = torch.randint(self.operators.size(0), (nb,))
123         operators = self.operators[nb_operators]
124         nb_operators = (
125             nb_operators[:, None]
126             // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
127         ) % 10
128         marker1 = torch.full((nb, 1), 10)
129         # source = torch.randint(10, (nb, self.len_source))
130         source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
131         marker2 = torch.full((nb, 1), 11)
132         result = operators.bmm(source[:, :, None]).squeeze(-1)
133         sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
134         ar_mask = (sequences == 11).long()
135         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
136         return sequences, ar_mask
137
138     def seq2str(self, seq):
139         return "".join("0123456789|>"[x.item()] for x in seq)
140
141
142 ####################
143
144
145 class ProblemLevel2(Problem):
146     def __init__(self, len_source=5, len_result=8):
147         self.len_source = len_source
148         self.len_result = len_result
149
150     def generate_sequences(self, nb):
151         operators = F.one_hot(
152             torch.rand(nb, self.len_result, self.len_source).argmax(-1),
153             num_classes=self.len_source,
154         )
155         source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
156         marker1 = torch.full((nb, 1), 10)
157         result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
158         marker2 = torch.full((nb, 1), 11)
159         source2 = torch.randint(10, (nb, self.len_source))
160         marker3 = torch.full((nb, 1), 12)
161         result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
162
163         sequences = torch.cat(
164             (source1, marker1, result1, marker2, source2, marker3, result2), 1
165         )
166         ar_mask = (sequences == 12).long()
167         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
168         return sequences, ar_mask
169
170     def seq2str(self, seq):
171         return "".join("0123456789>|~"[x.item()] for x in seq)
172
173
174 ####################
175
176
177 class ProblemAddition(Problem):
178     def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
179         self.nb_digits = nb_digits
180         self.zero_padded = zero_padded
181         self.inverted_result = inverted_result
182         self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")])
183         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
184
185     def tensorize(self, strings):
186         len_max = max([len(x) for x in strings])
187         return torch.cat(
188             [
189                 torch.tensor(
190                     [
191                         [self.char2id[c] for c in s + "$" * (len_max - len(s))]
192                         for s in strings
193                     ]
194                 )
195             ],
196             0,
197         )
198
199     def generate_sequences(self, nb):
200         sequences = []
201         for k in range(nb):
202             a, b = torch.randint(10**self.nb_digits, (2,))
203             c = a + b
204             a, b, c = str(a.item()), str(b.item()), str(c.item())
205             if self.zero_padded:
206                 a = "0" * (self.nb_digits - len(a)) + a
207                 b = "0" * (self.nb_digits - len(b)) + b
208                 c = "0" * (self.nb_digits + 1 - len(c)) + c
209             if self.inverted_result:
210                 c = c[::-1]
211             sequences.append(f"{a}+{b}={c}$")
212
213         sequences = self.tensorize(sequences)
214         ar_mask = (sequences == self.char2id["="]).long()
215         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
216         return sequences, ar_mask
217
218     def seq2str(self, seq):
219         return "".join(self.id2char[x.item()] for x in seq)
220
221
222 if __name__ == "__main__":
223     p = ProblemTwoTargets(12, 4)
224     s, m = p.generate_sequences(10)
225     for x in s:
226         print(p.seq2str(x))