5 import torch, torchvision
8 from torch.nn import functional as F
10 ######################################################################
14 def generate_sequences(self, nb):
def seq2str(self, seq):
    """Return a textual rendering of ``seq``.

    Placeholder implementation: ``seq`` is ignored and the same fixed
    marker string is returned for every input.
    """
    placeholder = "[NOT IMPLEMENTED]"
    return placeholder
24 class ProblemLenId(Problem):
def __init__(self, nb_sentences=100, len_max=5):
    """Record the length bound read by ``generate_sequences``.

    Args:
        nb_sentences: Accepted for interface compatibility; this
            constructor does not read it.
        len_max: Stored unchanged as ``self.len_max``.
    """
    # Only the length bound is kept; nothing is pre-generated here.
    self.len_max = len_max
28 def generate_sequences(self, nb):
29 k = torch.arange(self.len_max * 3 + 3)[None, :]
30 l = torch.randint(self.len_max, (2, nb))[:, :, None] + 1
31 i = torch.randint(10, (2, nb))[:, :, None]
34 c = l[0] + 1 + l[1] + 1 + l[0]
38 + (k > a) * (k < b) * i[1]
40 + (k > b) * (k < c) * i[1]
44 ar_mask = (sequences == 11).long()
45 ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
46 return sequences, ar_mask
def seq2str(self, seq):
    """Decode an iterable of token ids into a string.

    Ids 0-9 are rendered as their digit; ids 10, 11, 12 and 13 are
    rendered as ``|``, ``>``, ``.`` and ``?``. Every element of ``seq``
    must provide ``.item()`` returning an integer index.
    """
    alphabet = "0123456789|>.?"
    pieces = [alphabet[token.item()] for token in seq]
    return "".join(pieces)
class ProblemLevel0(Problem):
    """Problem built on a fixed pool of random ``prompt|result`` sequences.

    The pool is drawn once at construction time; batches are then obtained
    by picking rows of that pool at random (with repetition).
    """

    def __init__(self, nb_sentences=100, len_prompt=5, len_result=5):
        """Draw the pool of sequences.

        Args:
            nb_sentences: Number of rows in the pool.
            len_prompt: Number of digit tokens before the separator.
            len_result: Number of digit tokens after the separator.
        """
        total_len = len_prompt + 1 + len_result
        self.seq = torch.randint(10, (nb_sentences, total_len))
        # Token 10 is the separator, written right after the prompt.
        self.seq[:, len_prompt] = 10

    def generate_sequences(self, nb):
        """Sample ``nb`` rows of the pool.

        Returns:
            A pair ``(sequences, ar_mask)``; ``ar_mask`` has the shape of
            ``sequences`` and is 1 at the positions located strictly after
            the first separator token (10) of each row, 0 elsewhere.
        """
        picks = torch.randint(self.seq.size(0), (nb,))
        sequences = self.seq[picks]
        is_separator = sequences.eq(10).long()
        # The running count minus the indicator is 0 up to and including
        # the first separator, and positive afterwards.
        after_separator = is_separator.cumsum(dim=1) - is_separator
        ar_mask = after_separator.clamp(max=1)
        return sequences, ar_mask

    def seq2str(self, seq):
        """Decode token ids into text: digits for 0-9, ``|`` for 10."""
        alphabet = "0123456789|"
        pieces = [alphabet[token.item()] for token in seq]
        return "".join(pieces)
73 class ProblemLevel1(Problem):
def __init__(self, nb_operators=100, len_source=5, len_result=8):
    """Draw a fixed bank of random selection operators.

    Args:
        nb_operators: Number of operators in the bank.
        len_source: Number of source positions an operator reads from.
        len_result: Number of result positions an operator produces.
    """
    self.len_source = len_source
    self.len_result = len_result
    # NOTE(review): presumably the number of base-10 digits reserved for
    # writing an operator index -- confirm against generate_sequences.
    # The log ratio is kept as is: math.log10 rounds differently at exact
    # powers of ten and would change the stored value.
    digit_count = int(math.log(nb_operators) / math.log(10)) + 1
    self.len_nb_operator = digit_count
    # For each (operator, result position), the source position with the
    # largest uniform draw is selected, then one-hot encoded over the
    # source positions.
    picked = torch.rand(nb_operators, len_result, len_source).argmax(-1)
    self.operators = F.one_hot(picked, num_classes=len_source)
83 def generate_sequences(self, nb):
84 nb_operators = torch.randint(self.operators.size(0), (nb,))
85 operators = self.operators[nb_operators]
88 // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
90 marker1 = torch.full((nb, 1), 10)
91 # source = torch.randint(10, (nb, self.len_source))
92 source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
93 marker2 = torch.full((nb, 1), 11)
94 result = operators.bmm(source[:, :, None]).squeeze(-1)
95 sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
96 ar_mask = (sequences == 11).long()
97 ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
98 return sequences, ar_mask
def seq2str(self, seq):
    """Decode an iterable of token ids into a string.

    Ids 0-9 are rendered as their digit, 10 as ``|`` and 11 as ``>``.
    Every element of ``seq`` must provide ``.item()``.
    """
    alphabet = "0123456789|>"
    pieces = [alphabet[token.item()] for token in seq]
    return "".join(pieces)
class ProblemLevel2(Problem):
    """Problem where a selection operator is shown once and applied again.

    Each generated row is the concatenation
    ``source1, 10, result1, 11, source2, 12, result2`` in which both
    results come from applying the same random per-row operator to their
    respective source.
    """

    def __init__(self, len_source=5, len_result=8):
        """Store the source and result lengths used when generating."""
        self.len_source = len_source
        self.len_result = len_result

    def generate_sequences(self, nb):
        """Generate ``nb`` rows and their auto-regression mask.

        Returns:
            A pair ``(sequences, ar_mask)``; ``ar_mask`` has the shape of
            ``sequences`` and is 1 at the positions located strictly after
            the first token 12 of each row, 0 elsewhere.
        """
        # One operator per row: every result position copies the source
        # position that received the largest uniform draw.
        chosen = torch.rand(nb, self.len_result, self.len_source).argmax(-1)
        operators = F.one_hot(chosen, num_classes=self.len_source)

        def apply_operator(source):
            # (nb, len_result, len_source) @ (nb, len_source, 1)
            return operators.bmm(source[:, :, None]).squeeze(-1)

        # First source: leading columns of a random ordering of 0..9, so
        # its digits are pairwise distinct within a row.
        ordering = torch.rand(nb, 10).sort(dim=1).indices
        source1 = ordering[:, : self.len_source]
        result1 = apply_operator(source1)

        # Second source: independent digits, repetitions allowed.
        source2 = torch.randint(10, (nb, self.len_source))
        result2 = apply_operator(source2)

        marker1, marker2, marker3 = (
            torch.full((nb, 1), code) for code in (10, 11, 12)
        )
        parts = [source1, marker1, result1, marker2, source2, marker3, result2]
        sequences = torch.cat(parts, 1)

        is_last_marker = sequences.eq(12).long()
        # The running count minus the indicator is 0 up to and including
        # the marker, and positive afterwards.
        after_marker = is_last_marker.cumsum(dim=1) - is_last_marker
        ar_mask = after_marker.clamp(max=1)
        return sequences, ar_mask

    def seq2str(self, seq):
        """Decode token ids: digits for 0-9, then ``>``, ``|``, ``~``."""
        alphabet = "0123456789>|~"
        pieces = [alphabet[token.item()] for token in seq]
        return "".join(pieces)
139 class ProblemAddition(Problem):
def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
    """Store the generation options and build the character vocabulary.

    Args:
        nb_digits: Stored as ``self.nb_digits``; ``generate_sequences``
            draws operands below ``10 ** nb_digits``.
        zero_padded: Stored as ``self.zero_padded``.
        inverted_result: Stored as ``self.inverted_result`` and tested by
            ``generate_sequences``.
    """
    self.nb_digits = nb_digits
    self.zero_padded = zero_padded
    self.inverted_result = inverted_result
    # Vocabulary: the ten digits, "+", "=" and "$" (ids 0..12 in that
    # order). Dict comprehensions avoid building a throwaway list of
    # tuples just to feed dict().
    self.char2id = {char: index for index, char in enumerate("0123456789+=$")}
    # Reverse lookup used when decoding token ids back to text.
    self.id2char = {index: char for char, index in self.char2id.items()}
147 def tensorize(self, strings):
148 len_max = max([len(x) for x in strings])
153 [self.char2id[c] for c in s + "$" * (len_max - len(s))]
161 def generate_sequences(self, nb):
164 a, b = torch.randint(10**self.nb_digits, (2,))
166 a, b, c = str(a.item()), str(b.item()), str(c.item())
168 a = "0" * (self.nb_digits - len(a)) + a
169 b = "0" * (self.nb_digits - len(b)) + b
170 c = "0" * (self.nb_digits + 1 - len(c)) + c
171 if self.inverted_result:
173 sequences.append(f"{a}+{b}={c}$")
175 sequences = self.tensorize(sequences)
176 ar_mask = (sequences == self.char2id["="]).long()
177 ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
178 return sequences, ar_mask
def seq2str(self, seq):
    """Decode an iterable of token ids into a string.

    Each element of ``seq`` must provide ``.item()``; the resulting value
    is looked up in ``self.id2char`` and the characters are concatenated.
    """
    lookup = self.id2char
    pieces = [lookup[token.item()] for token in seq]
    return "".join(pieces)
184 # class ProblemUnion(Problem):
185 # problems = [ProblemByheart()]
186 # nb_common_codes = 100
188 # def generate_sequences(nb_samples):
189 # problem_indexes = torch.randint(len(problems), (nb_samples,))
190 # nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0)
191 # print(f"{nb_samples_per_problem}")
193 # for nb, p in zip(nb_samples_per_problem, problems):
194 # all_seq.append(p.generate_sequences(nb_samples_per_problem[nb]))
197 # for strain, stest in zip(train_seq, test_seq):
198 # s = torch.cat((strain, stest), 0)