Update.
[picoclvr.git] / problems.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10 ######################################################################
11
12
13 class Problem:
14     def generate_sequences(self, nb):
15         pass
16
17     def seq2str(self, seq):
18         return "[NOT IMPLEMENTED]"
19
20
21 ####################
22
23
24 class ProblemLenId(Problem):
25     def __init__(self, nb_sentences=100, len_max=5):
26         self.len_max = len_max
27
28     def generate_sequences(self, nb):
29         k = torch.arange(self.len_max * 3 + 3)[None, :]
30         l = torch.randint(self.len_max, (2, nb))[:, :, None] + 1
31         i = torch.randint(10, (2, nb))[:, :, None]
32         a = l[0]
33         b = l[0] + 1 + l[1]
34         c = l[0] + 1 + l[1] + 1 + l[0]
35         sequences = (
36             (k < a) * i[0]
37             + (k == a) * 10
38             + (k > a) * (k < b) * i[1]
39             + (k == b) * 11
40             + (k > b) * (k < c) * i[1]
41             + (k == c) * 12
42             + (k > c) * 13
43         )
44         ar_mask = (sequences == 11).long()
45         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
46         return sequences, ar_mask
47
48     def seq2str(self, seq):
49         return "".join("0123456789|>.?"[x.item()] for x in seq)
50
51
52 ####################
53
54
55 class ProblemLevel0(Problem):
56     def __init__(self, nb_sentences=100, len_prompt=5, len_result=5):
57         self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
58         self.seq[:, len_prompt] = 10
59
60     def generate_sequences(self, nb):
61         sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
62         ar_mask = (sequences == 10).long()
63         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
64         return sequences, ar_mask
65
66     def seq2str(self, seq):
67         return "".join("0123456789|"[x.item()] for x in seq)
68
69
70 ####################
71
72
73 class ProblemLevel1(Problem):
74     def __init__(self, nb_operators=100, len_source=5, len_result=8):
75         self.len_source = len_source
76         self.len_result = len_result
77         self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
78         self.operators = F.one_hot(
79             torch.rand(nb_operators, len_result, len_source).argmax(-1),
80             num_classes=len_source,
81         )
82
83     def generate_sequences(self, nb):
84         nb_operators = torch.randint(self.operators.size(0), (nb,))
85         operators = self.operators[nb_operators]
86         nb_operators = (
87             nb_operators[:, None]
88             // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
89         ) % 10
90         marker1 = torch.full((nb, 1), 10)
91         # source = torch.randint(10, (nb, self.len_source))
92         source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
93         marker2 = torch.full((nb, 1), 11)
94         result = operators.bmm(source[:, :, None]).squeeze(-1)
95         sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
96         ar_mask = (sequences == 11).long()
97         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
98         return sequences, ar_mask
99
100     def seq2str(self, seq):
101         return "".join("0123456789|>"[x.item()] for x in seq)
102
103
104 ####################
105
106
107 class ProblemLevel2(Problem):
108     def __init__(self, len_source=5, len_result=8):
109         self.len_source = len_source
110         self.len_result = len_result
111
112     def generate_sequences(self, nb):
113         operators = F.one_hot(
114             torch.rand(nb, self.len_result, self.len_source).argmax(-1),
115             num_classes=self.len_source,
116         )
117         source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
118         marker1 = torch.full((nb, 1), 10)
119         result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
120         marker2 = torch.full((nb, 1), 11)
121         source2 = torch.randint(10, (nb, self.len_source))
122         marker3 = torch.full((nb, 1), 12)
123         result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
124
125         sequences = torch.cat(
126             (source1, marker1, result1, marker2, source2, marker3, result2), 1
127         )
128         ar_mask = (sequences == 12).long()
129         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
130         return sequences, ar_mask
131
132     def seq2str(self, seq):
133         return "".join("0123456789>|~"[x.item()] for x in seq)
134
135
136 ####################
137
138
139 class ProblemAddition(Problem):
140     def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
141         self.nb_digits = nb_digits
142         self.zero_padded = zero_padded
143         self.inverted_result = inverted_result
144         self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")])
145         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
146
147     def tensorize(self, strings):
148         len_max = max([len(x) for x in strings])
149         return torch.cat(
150             [
151                 torch.tensor(
152                     [
153                         [self.char2id[c] for c in s + "$" * (len_max - len(s))]
154                         for s in strings
155                     ]
156                 )
157             ],
158             0,
159         )
160
161     def generate_sequences(self, nb):
162         sequences = []
163         for k in range(nb):
164             a, b = torch.randint(10**self.nb_digits, (2,))
165             c = a + b
166             a, b, c = str(a.item()), str(b.item()), str(c.item())
167             if self.zero_padded:
168                 a = "0" * (self.nb_digits - len(a)) + a
169                 b = "0" * (self.nb_digits - len(b)) + b
170                 c = "0" * (self.nb_digits + 1 - len(c)) + c
171             if self.inverted_result:
172                 c = c[::-1]
173             sequences.append(f"{a}+{b}={c}$")
174
175         sequences = self.tensorize(sequences)
176         ar_mask = (sequences == self.char2id["="]).long()
177         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
178         return sequences, ar_mask
179
180     def seq2str(self, seq):
181         return "".join(self.id2char[x.item()] for x in seq)
182
183
184 # class ProblemUnion(Problem):
185 # problems = [ProblemByheart()]
186 # nb_common_codes = 100
187
188 # def generate_sequences(nb_samples):
189 # problem_indexes = torch.randint(len(problems), (nb_samples,))
190 # nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0)
191 # print(f"{nb_samples_per_problem}")
192 # all_seq = []
193 # for nb, p in zip(nb_samples_per_problem, problems):
194 # all_seq.append(p.generate_sequences(nb_samples_per_problem[nb]))
195 # return all_seq
196
197 # for strain, stest in zip(train_seq, test_seq):
198 # s = torch.cat((strain, stest), 0)