Update.
[picoclvr.git] / problems.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10 ######################################################################
11
12
13 class Problem:
14     def generate_sequences(self, nb):
15         pass
16
17     def seq2str(self, seq):
18         return "[NOT IMPLEMENTED]"
19
20
21 ####################
22
23
24 class ProblemTwoTargets(Problem):
25     def __init__(self, len_total=10, len_targets=3):
26         assert len_targets >= 3
27         assert len_total >= 3 * len_targets - 1
28         self.len_total = len_total
29         self.len_targets = len_targets
30
31     def generate_sequences(self, nb):
32         k = torch.arange(self.len_total)[None, :]
33         s = torch.randint(10, (nb, self.len_total))
34         l = torch.rand(nb, self.len_total)
35         l = l * (k <= self.len_total - self.len_targets).long()
36         k1 = l.argmax(dim=1, keepdim=True)
37         m = (k != k1).long() * (k != k1 + self.len_targets - 1).long()
38         s = s * m + 10 * (1 - m)
39         l = l * (
40             1
41             - (k + self.len_targets - 1 >= k1).long()
42             * (k < k1 + self.len_targets).long()
43         )
44         k2 = l.argmax(dim=1, keepdim=True)
45         m = (k != k2).long() * (k != k2 + self.len_targets - 1).long()
46         s = s * m + 11 * (1 - m)
47         a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
48         a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
49         sequences = torch.cat(
50             (
51                 s,
52                 torch.full((nb, 1), 12),
53                 a1,
54                 torch.full((nb, 1), 12),
55                 a2,
56                 torch.full((nb, 1), 12),
57             ),
58             1,
59         )
60         ar_mask = (sequences == 12).long()
61         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
62         return sequences, ar_mask
63
64     def seq2str(self, seq):
65         return "".join("0123456789-+|"[x.item()] for x in seq)
66
67
68 ####################
69
70
71 class ProblemByHeart(Problem):
72     def __init__(self, nb_sentences=100, len_prompt=8, len_result=8):
73         self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
74         self.seq[:, len_prompt] = 10
75
76     def generate_sequences(self, nb):
77         sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
78         ar_mask = (sequences == 10).long()
79         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
80         return sequences, ar_mask
81
82     def seq2str(self, seq):
83         return "".join("0123456789|"[x.item()] for x in seq)
84
85
86 ####################
87
88
89 class ProblemLearnOperator(Problem):
90     def __init__(self, nb_operators=100, len_source=5, len_result=8):
91         self.len_source = len_source
92         self.len_result = len_result
93         self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
94         self.operators = F.one_hot(
95             torch.rand(nb_operators, len_result, len_source).argmax(-1),
96             num_classes=len_source,
97         )
98
99     def generate_sequences(self, nb):
100         nb_operators = torch.randint(self.operators.size(0), (nb,))
101         operators = self.operators[nb_operators]
102         nb_operators = (
103             nb_operators[:, None]
104             // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
105         ) % 10
106         marker1 = torch.full((nb, 1), 10)
107         source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
108         marker2 = torch.full((nb, 1), 11)
109         result = operators.bmm(source[:, :, None]).squeeze(-1)
110         sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
111         ar_mask = (sequences == 11).long()
112         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
113         return sequences, ar_mask
114
115     def seq2str(self, seq):
116         return "".join("0123456789|>"[x.item()] for x in seq)
117
118
119 ####################
120
121
122 class ProblemGuessOperator(Problem):
123     def __init__(self, len_source=5, len_result=8):
124         self.len_source = len_source
125         self.len_result = len_result
126
127     def generate_sequences(self, nb):
128         operators = F.one_hot(
129             torch.rand(nb, self.len_result, self.len_source).argmax(-1),
130             num_classes=self.len_source,
131         )
132         source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
133         marker1 = torch.full((nb, 1), 10)
134         result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
135         marker2 = torch.full((nb, 1), 11)
136         source2 = torch.randint(10, (nb, self.len_source))
137         marker3 = torch.full((nb, 1), 12)
138         result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
139
140         sequences = torch.cat(
141             (source1, marker1, result1, marker2, source2, marker3, result2), 1
142         )
143         ar_mask = (sequences == 12).long()
144         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
145         return sequences, ar_mask
146
147     def seq2str(self, seq):
148         return "".join("0123456789>|~"[x.item()] for x in seq)
149
150
151 ####################
152
153
154 class ProblemAddition(Problem):
155     def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
156         self.nb_digits = nb_digits
157         self.zero_padded = zero_padded
158         self.inverted_result = inverted_result
159         self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")])
160         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
161
162     def tensorize(self, strings):
163         len_max = max([len(x) for x in strings])
164         return torch.cat(
165             [
166                 torch.tensor(
167                     [
168                         [self.char2id[c] for c in s + "$" * (len_max - len(s))]
169                         for s in strings
170                     ]
171                 )
172             ],
173             0,
174         )
175
176     def generate_sequences(self, nb):
177         sequences = []
178         for k in range(nb):
179             a, b = torch.randint(10**self.nb_digits, (2,))
180             c = a + b
181             a, b, c = str(a.item()), str(b.item()), str(c.item())
182             if self.zero_padded:
183                 a = "0" * (self.nb_digits - len(a)) + a
184                 b = "0" * (self.nb_digits - len(b)) + b
185                 c = "0" * (self.nb_digits + 1 - len(c)) + c
186             if self.inverted_result:
187                 c = c[::-1]
188             sequences.append(f"{a}+{b}={c}$")
189
190         sequences = self.tensorize(sequences)
191         ar_mask = (sequences == self.char2id["="]).long()
192         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
193         return sequences, ar_mask
194
195     def seq2str(self, seq):
196         return "".join(self.id2char[x.item()] for x in seq)
197
198
199 if __name__ == "__main__":
200     p = ProblemTwoTargets(12, 4)
201     s, m = p.generate_sequences(10)
202     for x in s:
203         print(p.seq2str(x))