Update.
[picoclvr.git] / problems.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10 ######################################################################
11
12
13 class Problem:
14     def generate_sequences(self, nb):
15         pass
16
17     def seq2str(self, seq):
18         return "[NOT IMPLEMENTED]"
19
20
21 ####################
22
23
24 class ProblemLenId(Problem):
25     def __init__(self, len_max=10):
26         self.len_max = len_max
27
28     def generate_sequences(self, nb):
29         k = torch.arange(self.len_max * 3 + 3)[None, :]
30         l = torch.randint(self.len_max, (2, nb))[:, :, None] + 1
31         i = torch.randint(10, (2, nb))[:, :, None]
32         a = l[0]
33         b = l[0] + 1 + l[1]
34         c = l[0] + 1 + l[1] + 1 + l[0]
35         sequences = (
36             (k < a) * i[0]
37             + (k == a) * 10
38             + (k > a) * (k < b) * i[1]
39             + (k == b) * 11
40             + (k > b) * (k < c) * i[1]
41             + (k >= c) * 12
42         )
43         ar_mask = (sequences == 11).long()
44         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
45         return sequences, ar_mask
46
47     def seq2str(self, seq):
48         return "".join("0123456789|>_"[x.item()] for x in seq)
49
50
51 ####################
52
53
54 class ProblemLevel0(Problem):
55     def __init__(self, nb_sentences=100, len_prompt=5, len_result=5):
56         self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
57         self.seq[:, len_prompt] = 10
58
59     def generate_sequences(self, nb):
60         sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
61         ar_mask = (sequences == 10).long()
62         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
63         return sequences, ar_mask
64
65     def seq2str(self, seq):
66         return "".join("0123456789|"[x.item()] for x in seq)
67
68
69 ####################
70
71
72 class ProblemLevel1(Problem):
73     def __init__(self, nb_operators=100, len_source=5, len_result=8):
74         self.len_source = len_source
75         self.len_result = len_result
76         self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
77         self.operators = F.one_hot(
78             torch.rand(nb_operators, len_result, len_source).argmax(-1),
79             num_classes=len_source,
80         )
81
82     def generate_sequences(self, nb):
83         nb_operators = torch.randint(self.operators.size(0), (nb,))
84         operators = self.operators[nb_operators]
85         nb_operators = (
86             nb_operators[:, None]
87             // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
88         ) % 10
89         marker1 = torch.full((nb, 1), 10)
90         # source = torch.randint(10, (nb, self.len_source))
91         source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
92         marker2 = torch.full((nb, 1), 11)
93         result = operators.bmm(source[:, :, None]).squeeze(-1)
94         sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
95         ar_mask = (sequences == 11).long()
96         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
97         return sequences, ar_mask
98
99     def seq2str(self, seq):
100         return "".join("0123456789|>"[x.item()] for x in seq)
101
102
103 ####################
104
105
106 class ProblemLevel2(Problem):
107     def __init__(self, len_source=5, len_result=8):
108         self.len_source = len_source
109         self.len_result = len_result
110
111     def generate_sequences(self, nb):
112         operators = F.one_hot(
113             torch.rand(nb, self.len_result, self.len_source).argmax(-1),
114             num_classes=self.len_source,
115         )
116         source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
117         marker1 = torch.full((nb, 1), 10)
118         result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
119         marker2 = torch.full((nb, 1), 11)
120         source2 = torch.randint(10, (nb, self.len_source))
121         marker3 = torch.full((nb, 1), 12)
122         result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
123
124         sequences = torch.cat(
125             (source1, marker1, result1, marker2, source2, marker3, result2), 1
126         )
127         ar_mask = (sequences == 12).long()
128         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
129         return sequences, ar_mask
130
131     def seq2str(self, seq):
132         return "".join("0123456789>|~"[x.item()] for x in seq)
133
134
135 ####################
136
137
138 class ProblemAddition(Problem):
139     def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
140         self.nb_digits = nb_digits
141         self.zero_padded = zero_padded
142         self.inverted_result = inverted_result
143         self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")])
144         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
145
146     def tensorize(self, strings):
147         len_max = max([len(x) for x in strings])
148         return torch.cat(
149             [
150                 torch.tensor(
151                     [
152                         [self.char2id[c] for c in s + "$" * (len_max - len(s))]
153                         for s in strings
154                     ]
155                 )
156             ],
157             0,
158         )
159
160     def generate_sequences(self, nb):
161         sequences = []
162         for k in range(nb):
163             a, b = torch.randint(10**self.nb_digits, (2,))
164             c = a + b
165             a, b, c = str(a.item()), str(b.item()), str(c.item())
166             if self.zero_padded:
167                 a = "0" * (self.nb_digits - len(a)) + a
168                 b = "0" * (self.nb_digits - len(b)) + b
169                 c = "0" * (self.nb_digits + 1 - len(c)) + c
170             if self.inverted_result:
171                 c = c[::-1]
172             sequences.append(f"{a}+{b}={c}$")
173
174         sequences = self.tensorize(sequences)
175         ar_mask = (sequences == self.char2id["="]).long()
176         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
177         return sequences, ar_mask
178
179     def seq2str(self, seq):
180         return "".join(self.id2char[x.item()] for x in seq)
181
182
183 # class ProblemUnion(Problem):
184 # problems = [ProblemByheart()]
185 # nb_common_codes = 100
186
187 # def generate_sequences(nb_samples):
188 # problem_indexes = torch.randint(len(problems), (nb_samples,))
189 # nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0)
190 # print(f"{nb_samples_per_problem}")
191 # all_seq = []
192 # for nb, p in zip(nb_samples_per_problem, problems):
193 # all_seq.append(p.generate_sequences(nb_samples_per_problem[nb]))
194 # return all_seq
195
196 # for strain, stest in zip(train_seq, test_seq):
197 # s = torch.cat((strain, stest), 0)