7b1d69859f6598fd6aee1cf832838678e2e5d377
[picoclvr.git] / problems.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10 ######################################################################
11
12
13 class Problem:
14     def generate_sequences(self, nb):
15         pass
16
17     def seq2str(self, seq):
18         return "[NOT IMPLEMENTED]"
19
20
21 ####################
22
23
24 class ProblemTwoTargets(Problem):
25     def __init__(self, len_total=10, len_target=2):
26         assert len_total >= 3 * (2 + len_target) - 1
27         self.len_total = len_total
28         self.len_target = len_target
29
30     def generate_sequences(self, nb):
31         k = torch.arange(self.len_total)[None, :]
32         l = torch.randint(self.len_total, (2, nb))[:, :, None] + 1
33         i = torch.randint(10, (2, nb))[:, :, None]
34         a = l[0]
35         b = l[0] + 1 + l[1]
36         c = l[0] + 1 + l[1] + 1 + l[0]
37         sequences = (
38             (k < a) * i[0]
39             + (k == a) * 10
40             + (k > a) * (k < b) * i[1]
41             + (k == b) * 11
42             + (k > b) * (k < c) * i[1]
43             + (k >= c) * 12
44         )
45         ar_mask = (sequences == 11).long()
46         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
47         return sequences, ar_mask
48
49     def seq2str(self, seq):
50         return "".join("0123456789|>_"[x.item()] for x in seq)
51
52
53 ####################
54
55
56 class ProblemLenId(Problem):
57     def __init__(self, len_max=10):
58         self.len_max = len_max
59
60     def generate_sequences(self, nb):
61         k = torch.arange(self.len_max * 3 + 3)[None, :]
62         l = torch.randint(self.len_max, (2, nb))[:, :, None] + 1
63         i = torch.randint(10, (2, nb))[:, :, None]
64         a = l[0]
65         b = l[0] + 1 + l[1]
66         c = l[0] + 1 + l[1] + 1 + l[0]
67         sequences = (
68             (k < a) * i[0]
69             + (k == a) * 10
70             + (k > a) * (k < b) * i[1]
71             + (k == b) * 11
72             + (k > b) * (k < c) * i[1]
73             + (k >= c) * 12
74         )
75         ar_mask = (sequences == 11).long()
76         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
77         return sequences, ar_mask
78
79     def seq2str(self, seq):
80         return "".join("0123456789|>_"[x.item()] for x in seq)
81
82
83 ####################
84
85
86 class ProblemLevel0(Problem):
87     def __init__(self, nb_sentences=100, len_prompt=5, len_result=5):
88         self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
89         self.seq[:, len_prompt] = 10
90
91     def generate_sequences(self, nb):
92         sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
93         ar_mask = (sequences == 10).long()
94         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
95         return sequences, ar_mask
96
97     def seq2str(self, seq):
98         return "".join("0123456789|"[x.item()] for x in seq)
99
100
101 ####################
102
103
104 class ProblemLevel1(Problem):
105     def __init__(self, nb_operators=100, len_source=5, len_result=8):
106         self.len_source = len_source
107         self.len_result = len_result
108         self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
109         self.operators = F.one_hot(
110             torch.rand(nb_operators, len_result, len_source).argmax(-1),
111             num_classes=len_source,
112         )
113
114     def generate_sequences(self, nb):
115         nb_operators = torch.randint(self.operators.size(0), (nb,))
116         operators = self.operators[nb_operators]
117         nb_operators = (
118             nb_operators[:, None]
119             // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
120         ) % 10
121         marker1 = torch.full((nb, 1), 10)
122         # source = torch.randint(10, (nb, self.len_source))
123         source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
124         marker2 = torch.full((nb, 1), 11)
125         result = operators.bmm(source[:, :, None]).squeeze(-1)
126         sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
127         ar_mask = (sequences == 11).long()
128         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
129         return sequences, ar_mask
130
131     def seq2str(self, seq):
132         return "".join("0123456789|>"[x.item()] for x in seq)
133
134
135 ####################
136
137
138 class ProblemLevel2(Problem):
139     def __init__(self, len_source=5, len_result=8):
140         self.len_source = len_source
141         self.len_result = len_result
142
143     def generate_sequences(self, nb):
144         operators = F.one_hot(
145             torch.rand(nb, self.len_result, self.len_source).argmax(-1),
146             num_classes=self.len_source,
147         )
148         source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
149         marker1 = torch.full((nb, 1), 10)
150         result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
151         marker2 = torch.full((nb, 1), 11)
152         source2 = torch.randint(10, (nb, self.len_source))
153         marker3 = torch.full((nb, 1), 12)
154         result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
155
156         sequences = torch.cat(
157             (source1, marker1, result1, marker2, source2, marker3, result2), 1
158         )
159         ar_mask = (sequences == 12).long()
160         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
161         return sequences, ar_mask
162
163     def seq2str(self, seq):
164         return "".join("0123456789>|~"[x.item()] for x in seq)
165
166
167 ####################
168
169
170 class ProblemAddition(Problem):
171     def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
172         self.nb_digits = nb_digits
173         self.zero_padded = zero_padded
174         self.inverted_result = inverted_result
175         self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")])
176         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
177
178     def tensorize(self, strings):
179         len_max = max([len(x) for x in strings])
180         return torch.cat(
181             [
182                 torch.tensor(
183                     [
184                         [self.char2id[c] for c in s + "$" * (len_max - len(s))]
185                         for s in strings
186                     ]
187                 )
188             ],
189             0,
190         )
191
192     def generate_sequences(self, nb):
193         sequences = []
194         for k in range(nb):
195             a, b = torch.randint(10**self.nb_digits, (2,))
196             c = a + b
197             a, b, c = str(a.item()), str(b.item()), str(c.item())
198             if self.zero_padded:
199                 a = "0" * (self.nb_digits - len(a)) + a
200                 b = "0" * (self.nb_digits - len(b)) + b
201                 c = "0" * (self.nb_digits + 1 - len(c)) + c
202             if self.inverted_result:
203                 c = c[::-1]
204             sequences.append(f"{a}+{b}={c}$")
205
206         sequences = self.tensorize(sequences)
207         ar_mask = (sequences == self.char2id["="]).long()
208         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
209         return sequences, ar_mask
210
211     def seq2str(self, seq):
212         return "".join(self.id2char[x.item()] for x in seq)
213
214
215 # class ProblemUnion(Problem):
216 # problems = [ProblemByheart()]
217 # nb_common_codes = 100
218
219 # def generate_sequences(nb_samples):
220 # problem_indexes = torch.randint(len(problems), (nb_samples,))
221 # nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0)
222 # print(f"{nb_samples_per_problem}")
223 # all_seq = []
224 # for nb, p in zip(nb_samples_per_problem, problems):
225 # all_seq.append(p.generate_sequences(nb_samples_per_problem[nb]))
226 # return all_seq
227
228 # for strain, stest in zip(train_seq, test_seq):
229 # s = torch.cat((strain, stest), 0)