5 import torch, torchvision
8 from torch.nn import functional as F
10 ######################################################################
# Problem (base class; its `class` header is outside this excerpt): common
# interface implemented by every synthetic task class below.
14 def generate_sequences(self, nb):
# NOTE(review): the base implementation's body is not visible in this
# excerpt. Every subclass below overrides this method and returns a
# `(sequences, ar_mask)` pair for `nb` samples.
17 def seq2str(self, seq):
# Fallback rendering for subclasses that do not override seq2str.
18 return "[NOT IMPLEMENTED]"
24 class ProblemTwoTargets(Problem):
# Synthetic task: a row of len_total random digits (tokens 0-9) containing two
# non-overlapping spans of length len_targets. The first span's two ends are
# overwritten with token 10 ('-') and the second span's with token 11 ('+').
# The interior digits of each span are gathered into a1 / a2, and the output
# contains three separator tokens 12 ('|').
# NOTE(review): several lines of this class are not visible in this excerpt
# (part of the second-span masking and the torch.cat arguments); the comments
# below describe only the visible code.
25 def __init__(self, len_total=10, len_targets=3):
# A span needs its two delimiters plus at least one interior digit.
26 assert len_targets >= 3
# Presumably guarantees that a valid start for the second span still exists
# after excluding the 2 * len_targets - 1 starts that would overlap the first.
27 assert len_total >= 3 * len_targets - 1
28 self.len_total = len_total
29 self.len_targets = len_targets
31 def generate_sequences(self, nb):
# Returns (sequences, ar_mask) for nb random samples.
# k: (1, len_total) position indices; s: (nb, len_total) random digits.
32 k = torch.arange(self.len_total)[None, :]
33 s = torch.randint(10, (nb, self.len_total))
# Random scores per position; argmax over the non-zeroed positions picks a
# random span start.
34 l = torch.rand(nb, self.len_total)
# Only keep starts where the whole span fits inside the row.
35 l = l * (k <= self.len_total - self.len_targets).long()
36 k1 = l.argmax(dim=1, keepdim=True)
# m is 0 exactly at the first span's two end cells; those cells become 10.
37 m = (k != k1).long() * (k != k1 + self.len_targets - 1).long()
38 s = s * m + 10 * (1 - m)
# (partially elided) These factors select starts k with
# k1 - len_targets + 1 <= k <= k1 + len_targets - 1, presumably removed from l
# so the second span cannot overlap the first.
41 - (k + self.len_targets - 1 >= k1).long()
42 * (k < k1 + self.len_targets).long()
44 k2 = l.argmax(dim=1, keepdim=True)
# Same end-marking for the second span, with token 11.
45 m = (k != k2).long() * (k != k2 + self.len_targets - 1).long()
46 s = s * m + 11 * (1 - m)
# The len_targets - 2 interior digits of each span, shape (nb, len_targets - 2).
47 a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
48 a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
# (partially elided) Concatenation along dim 1 that includes three separator
# columns of token 12; the remaining pieces (presumably s, a1, a2) are not
# visible here.
49 sequences = torch.cat(
52 torch.full((nb, 1), 12),
54 torch.full((nb, 1), 12),
56 torch.full((nb, 1), 12),
# ar_mask counts the 12 tokens strictly before each position, clamped to 1:
# 0 up to and including the first 12, and 1 afterwards.
60 ar_mask = (sequences == 12).long()
61 ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
62 return sequences, ar_mask
64 def seq2str(self, seq):
# Token ids 0-9 print as digits; 10, 11, 12 print as '-', '+', '|'.
65 return "".join("0123456789-+|"[x.item()] for x in seq)
71 class ProblemLenId(Problem):
# Synthetic task built from two random run lengths l[0], l[1] in 1..len_max and
# two random digits i[0], i[1]. Token ids 10, 11, 12 print as '|', '>', '_'.
# NOTE(review): several lines of this class are not visible in this excerpt
# (the definitions of a and b and most of the `sequences` expression); the
# comments below describe only the visible code.
72 def __init__(self, len_max=10):
73 self.len_max = len_max
75 def generate_sequences(self, nb):
# k: (1, 3 * len_max + 3) position indices, presumably the sequence width.
76 k = torch.arange(self.len_max * 3 + 3)[None, :]
# l: (2, nb, 1) lengths in 1..len_max.
77 l = torch.randint(self.len_max, (2, nb))[:, :, None] + 1
# i: (2, nb, 1) digits in 0..9.
78 i = torch.randint(10, (2, nb))[:, :, None]
# Exclusive upper bound of the last digit run: l[0] + 1 + l[1] + 1 + l[0].
81 c = l[0] + 1 + l[1] + 1 + l[0]
# (partially elided) Two of the terms summed into `sequences`: positions
# strictly between a and b, and strictly between b and c, receive digit i[1].
# a and b are defined on lines not visible here.
85 + (k > a) * (k < b) * i[1]
87 + (k > b) * (k < c) * i[1]
# 0 up to and including the first 11 ('>') token, 1 afterwards.
90 ar_mask = (sequences == 11).long()
91 ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
92 return sequences, ar_mask
94 def seq2str(self, seq):
95 return "".join("0123456789|>_"[x.item()] for x in seq)
class ProblemLevel0(Problem):
    """Samples drawn from a fixed bank of random sentences.

    Each sentence in the bank has three parts:

    - ``len_prompt`` random digits;
    - the separator token 10, rendered as ``'|'``;
    - ``len_result`` random digits.

    The bank is built once at construction time. ``generate_sequences`` draws
    rows from it independently, so rows can repeat.
    """

    def __init__(self, nb_sentences=100, len_prompt=5, len_result=5):
        width = len_prompt + 1 + len_result
        bank = torch.randint(10, (nb_sentences, width))
        # The column right after the prompt holds the separator token.
        bank[:, len_prompt] = 10
        self.seq = bank

    def generate_sequences(self, nb):
        """Return ``(sequences, ar_mask)`` for ``nb`` rows drawn from the bank.

        ``ar_mask`` is 0 up to and including the separator and 1 after it.
        """
        picks = torch.randint(self.seq.size(0), (nb,))
        sequences = self.seq[picks]
        is_sep = (sequences == 10).long()
        # Separators seen strictly before each position, saturated to {0, 1}.
        ar_mask = (is_sep.cumsum(1) - is_sep > 0).long()
        return sequences, ar_mask

    def seq2str(self, seq):
        """Render a 1-D token tensor: digits as-is, token 10 as ``'|'``."""
        alphabet = "0123456789|"
        return "".join(alphabet[token.item()] for token in seq)
119 class ProblemLevel1(Problem):
# Synthetic task over a fixed bank of nb_operators random selection operators.
# A sample is built from these pieces, in order:
# - the operator index written as decimal digits;
# - token 10 ('|');
# - a source of len_source distinct digits;
# - token 11 ('>');
# - the operator applied to the source.
# NOTE(review): a few lines of this class are not visible in this excerpt
# (including the start and end of the digit-extraction statement).
120 def __init__(self, nb_operators=100, len_source=5, len_result=8):
121 self.len_source = len_source
122 self.len_result = len_result
# Number of decimal digits used to write an operator index.
# NOTE(review): this counts the digits of nb_operators itself, not of the
# largest index nb_operators - 1. For example, nb_operators=100 gives 3 digits
# although indices stop at 99. It also depends on float rounding of the log
# ratio (math.log(1000) / math.log(10) is slightly below 3).
# len(str(nb_operators - 1)) would be exact.
# Also requires `math` to be imported at module level (not visible here).
123 self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
# operators: (nb_operators, len_result, len_source) one-hot tensor. Each output
# slot selects one source position, chosen by argmax of uniform noise.
124 self.operators = F.one_hot(
125 torch.rand(nb_operators, len_result, len_source).argmax(-1),
126 num_classes=len_source,
# NOTE(review): the closing parenthesis of this call is not visible here.
129 def generate_sequences(self, nb):
# Pick one operator from the bank per sample.
130 nb_operators = torch.randint(self.operators.size(0), (nb,))
131 operators = self.operators[nb_operators]
# (partially elided) Integer-divides each index by 10**(len_nb_operator - 1),
# ..., 10**0, presumably followed by % 10 to get its digits, most significant
# first.
133 nb_operators[:, None]
134 // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
136 marker1 = torch.full((nb, 1), 10)
137 # source = torch.randint(10, (nb, self.len_source))
# len_source distinct digits: the first columns of a random permutation of
# 0..9, so this presumably requires len_source <= 10.
138 source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
139 marker2 = torch.full((nb, 1), 11)
# Batched one-hot matrix-vector product: result[b, j] is the source digit
# selected by row j of sample b's operator.
140 result = operators.bmm(source[:, :, None]).squeeze(-1)
141 sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
# 0 up to and including the first 11 ('>') token, 1 afterwards.
142 ar_mask = (sequences == 11).long()
143 ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
144 return sequences, ar_mask
146 def seq2str(self, seq):
# Token ids 0-9 print as digits; 10 and 11 print as '|' and '>'.
147 return "".join("0123456789|>"[x.item()] for x in seq)
class ProblemLevel2(Problem):
    """Two source/result pairs that share one random selection operator.

    Each sample draws its own operator: every one of the ``len_result`` output
    slots copies one of the ``len_source`` input positions. A sequence is::

        source1 '>' op(source1) '|' source2 '~' op(source2)

    where ``'>'``, ``'|'`` and ``'~'`` are tokens 10, 11 and 12.
    """

    def __init__(self, len_source=5, len_result=8):
        self.len_source, self.len_result = len_source, len_result

    def generate_sequences(self, nb):
        """Return ``(sequences, ar_mask)`` for ``nb`` random samples.

        ``ar_mask`` is 0 up to and including token 12 and 1 afterwards, so it
        covers only the second result.
        """
        # (nb, len_result) chosen source positions -> one-hot selection matrices.
        chosen = torch.rand(nb, self.len_result, self.len_source).argmax(-1)
        operators = F.one_hot(chosen, num_classes=self.len_source)

        def apply_operator(source):
            # Batched matrix-vector product: slot j receives source[chosen[:, j]].
            return operators.bmm(source[:, :, None]).squeeze(-1)

        def marker(token):
            return torch.full((nb, 1), token)

        # source1 is len_source distinct digits: the first columns of a random
        # permutation of 0..9. source2 is drawn independently, so repeats are
        # allowed.
        source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
        result1 = apply_operator(source1)
        source2 = torch.randint(10, (nb, self.len_source))
        result2 = apply_operator(source2)

        sequences = torch.cat(
            [source1, marker(10), result1, marker(11), source2, marker(12), result2],
            dim=1,
        )
        is_last_marker = (sequences == 12).long()
        # Markers seen strictly before each position, saturated to {0, 1}.
        ar_mask = (is_last_marker.cumsum(1) - is_last_marker > 0).long()
        return sequences, ar_mask

    def seq2str(self, seq):
        """Render a 1-D token tensor; tokens 10, 11, 12 print as '>', '|', '~'."""
        alphabet = "0123456789>|~"
        return "".join(alphabet[token.item()] for token in seq)
185 class ProblemAddition(Problem):
# Synthetic task: decimal addition written as the string "{a}+{b}={c}$" and
# tokenized character by character over the alphabet "0123456789+=$".
# NOTE(review): several lines of this class are not visible in this excerpt
# (the generation loop, the computation of c, the padding condition, the
# inverted-result body, and most of tensorize); the comments below describe
# only the visible code.
186 def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
187 self.nb_digits = nb_digits
188 self.zero_padded = zero_padded
189 self.inverted_result = inverted_result
# Character <-> token-id lookup tables: '0'-'9' -> 0-9, '+' -> 10, '=' -> 11,
# '$' -> 12.
190 self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")])
191 self.id2char = dict([(n, c) for c, n in self.char2id.items()])
193 def tensorize(self, strings):
# Every string is right-padded with '$' to the longest length before lookup.
194 len_max = max([len(x) for x in strings])
199 [self.char2id[c] for c in s + "$" * (len_max - len(s))]
207 def generate_sequences(self, nb):
# Operands are drawn uniformly from [0, 10**nb_digits).
210 a, b = torch.randint(10**self.nb_digits, (2,))
# c is computed on a line not visible here (presumably a + b).
212 a, b, c = str(a.item()), str(b.item()), str(c.item())
# Left-pad with zeros: operands to nb_digits characters, c to nb_digits + 1.
# NOTE(review): the guarding condition (presumably self.zero_padded) is not
# visible here.
214 a = "0" * (self.nb_digits - len(a)) + a
215 b = "0" * (self.nb_digits - len(b)) + b
216 c = "0" * (self.nb_digits + 1 - len(c)) + c
# NOTE(review): the body of this branch is not visible here (presumably it
# reverses c).
217 if self.inverted_result:
219 sequences.append(f"{a}+{b}={c}$")
221 sequences = self.tensorize(sequences)
# 0 up to and including the first '=' token, 1 afterwards.
222 ar_mask = (sequences == self.char2id["="]).long()
223 ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
224 return sequences, ar_mask
226 def seq2str(self, seq):
227 return "".join(self.id2char[x.item()] for x in seq)
230 if __name__ == "__main__":
# Smoke test: sample 10 sequences (and their autoregression masks) from a
# ProblemTwoTargets task with len_total=12 and len_targets=4.
231 p = ProblemTwoTargets(12, 4)
232 s, m = p.generate_sequences(10)