2e0ca36a3803b629e17d78963a2096a3fc6347fc
[picoclvr.git] / problems.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10 ######################################################################
11
12
13 class Problem:
14     def generate_sequences(self, nb):
15         pass
16
17     def seq2str(self, seq):
18         return "[NOT IMPLEMENTED]"
19
20
21 ####################
22
23
24 class ProblemTwoTargets(Problem):
25     def __init__(self, len_total=10, len_targets=3):
26         assert len_targets >= 3
27         assert len_total >= 3 * len_targets - 1
28         self.len_total = len_total
29         self.len_targets = len_targets
30
31     def generate_sequences(self, nb):
32         k = torch.arange(self.len_total)[None, :]
33         s = torch.randint(10, (nb, self.len_total))
34         l = torch.rand(nb, self.len_total)
35         l = l * (k <= self.len_total - self.len_targets).long()
36         k1 = l.argmax(dim=1, keepdim=True)
37         m = (k != k1).long() * (k != k1 + self.len_targets - 1).long()
38         s = s * m + 10 * (1 - m)
39         l = l * (
40             1
41             - (k + self.len_targets - 1 >= k1).long()
42             * (k < k1 + self.len_targets).long()
43         )
44         k2 = l.argmax(dim=1, keepdim=True)
45         m = (k != k2).long() * (k != k2 + self.len_targets - 1).long()
46         s = s * m + 11 * (1 - m)
47         a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
48         a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
49         sequences = torch.cat(
50             (
51                 s,
52                 torch.full((nb, 1), 12),
53                 a1,
54                 torch.full((nb, 1), 12),
55                 a2,
56                 torch.full((nb, 1), 12),
57             ),
58             1,
59         )
60         ar_mask = (sequences == 12).long()
61         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
62         return sequences, ar_mask
63
64     def seq2str(self, seq):
65         return "".join("0123456789-+|"[x.item()] for x in seq)
66
67
68 ####################
69
70
71 class ProblemLenId(Problem):
72     def __init__(self, len_max=10):
73         self.len_max = len_max
74
75     def generate_sequences(self, nb):
76         k = torch.arange(self.len_max * 3 + 3)[None, :]
77         l = torch.randint(self.len_max, (2, nb))[:, :, None] + 1
78         i = torch.randint(10, (2, nb))[:, :, None]
79         a = l[0]
80         b = l[0] + 1 + l[1]
81         c = l[0] + 1 + l[1] + 1 + l[0]
82         sequences = (
83             (k < a) * i[0]
84             + (k == a) * 10
85             + (k > a) * (k < b) * i[1]
86             + (k == b) * 11
87             + (k > b) * (k < c) * i[1]
88             + (k >= c) * 12
89         )
90         ar_mask = (sequences == 11).long()
91         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
92         return sequences, ar_mask
93
94     def seq2str(self, seq):
95         return "".join("0123456789|>_"[x.item()] for x in seq)
96
97
98 ####################
99
100
101 class ProblemLevel0(Problem):
102     def __init__(self, nb_sentences=100, len_prompt=5, len_result=5):
103         self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
104         self.seq[:, len_prompt] = 10
105
106     def generate_sequences(self, nb):
107         sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
108         ar_mask = (sequences == 10).long()
109         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
110         return sequences, ar_mask
111
112     def seq2str(self, seq):
113         return "".join("0123456789|"[x.item()] for x in seq)
114
115
116 ####################
117
118
119 class ProblemLevel1(Problem):
120     def __init__(self, nb_operators=100, len_source=5, len_result=8):
121         self.len_source = len_source
122         self.len_result = len_result
123         self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
124         self.operators = F.one_hot(
125             torch.rand(nb_operators, len_result, len_source).argmax(-1),
126             num_classes=len_source,
127         )
128
129     def generate_sequences(self, nb):
130         nb_operators = torch.randint(self.operators.size(0), (nb,))
131         operators = self.operators[nb_operators]
132         nb_operators = (
133             nb_operators[:, None]
134             // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
135         ) % 10
136         marker1 = torch.full((nb, 1), 10)
137         # source = torch.randint(10, (nb, self.len_source))
138         source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
139         marker2 = torch.full((nb, 1), 11)
140         result = operators.bmm(source[:, :, None]).squeeze(-1)
141         sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
142         ar_mask = (sequences == 11).long()
143         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
144         return sequences, ar_mask
145
146     def seq2str(self, seq):
147         return "".join("0123456789|>"[x.item()] for x in seq)
148
149
150 ####################
151
152
153 class ProblemLevel2(Problem):
154     def __init__(self, len_source=5, len_result=8):
155         self.len_source = len_source
156         self.len_result = len_result
157
158     def generate_sequences(self, nb):
159         operators = F.one_hot(
160             torch.rand(nb, self.len_result, self.len_source).argmax(-1),
161             num_classes=self.len_source,
162         )
163         source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
164         marker1 = torch.full((nb, 1), 10)
165         result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
166         marker2 = torch.full((nb, 1), 11)
167         source2 = torch.randint(10, (nb, self.len_source))
168         marker3 = torch.full((nb, 1), 12)
169         result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
170
171         sequences = torch.cat(
172             (source1, marker1, result1, marker2, source2, marker3, result2), 1
173         )
174         ar_mask = (sequences == 12).long()
175         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
176         return sequences, ar_mask
177
178     def seq2str(self, seq):
179         return "".join("0123456789>|~"[x.item()] for x in seq)
180
181
182 ####################
183
184
185 class ProblemAddition(Problem):
186     def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
187         self.nb_digits = nb_digits
188         self.zero_padded = zero_padded
189         self.inverted_result = inverted_result
190         self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")])
191         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
192
193     def tensorize(self, strings):
194         len_max = max([len(x) for x in strings])
195         return torch.cat(
196             [
197                 torch.tensor(
198                     [
199                         [self.char2id[c] for c in s + "$" * (len_max - len(s))]
200                         for s in strings
201                     ]
202                 )
203             ],
204             0,
205         )
206
207     def generate_sequences(self, nb):
208         sequences = []
209         for k in range(nb):
210             a, b = torch.randint(10**self.nb_digits, (2,))
211             c = a + b
212             a, b, c = str(a.item()), str(b.item()), str(c.item())
213             if self.zero_padded:
214                 a = "0" * (self.nb_digits - len(a)) + a
215                 b = "0" * (self.nb_digits - len(b)) + b
216                 c = "0" * (self.nb_digits + 1 - len(c)) + c
217             if self.inverted_result:
218                 c = c[::-1]
219             sequences.append(f"{a}+{b}={c}$")
220
221         sequences = self.tensorize(sequences)
222         ar_mask = (sequences == self.char2id["="]).long()
223         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
224         return sequences, ar_mask
225
226     def seq2str(self, seq):
227         return "".join(self.id2char[x.item()] for x in seq)
228
229
230 if __name__ == "__main__":
231     p = ProblemTwoTargets(12, 4)
232     s, m = p.generate_sequences(10)
233     for x in s:
234         print(p.seq2str(x))