78bb64e601785bf2e116be86bd14faa4d7021696
[picoclvr.git] / problems.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10 ######################################################################
11
12
13 class Problem:
14     def generate_sequences(self, nb):
15         pass
16
17     def seq2str(self, seq):
18         return "[NOT IMPLEMENTED]"
19
20
21 ####################
22
23
24 class ProblemLevel0(Problem):
25     def __init__(self, nb_sentences=100, len_prompt=5, len_result=5):
26         self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
27         self.seq[:, len_prompt] = 10
28
29     def generate_sequences(self, nb):
30         sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
31         ar_mask = (sequences == 10).long()
32         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
33         return sequences, ar_mask
34
35
36 class ProblemLevel1(Problem):
37     def __init__(self, nb_operators=100, len_source=5, len_result=8):
38         self.len_source = len_source
39         self.len_result = len_result
40         self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
41         self.operators = F.one_hot(
42             torch.rand(nb_operators, len_result, len_source).argmax(-1),
43             num_classes=len_source,
44         )
45
46     def generate_sequences(self, nb):
47         nb_operators = torch.randint(self.operators.size(0), (nb,))
48         operators = self.operators[nb_operators]
49         nb_operators = (
50             nb_operators[:, None]
51             // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
52         ) % 10
53         marker1 = torch.full((nb, 1), 10)
54         # source = torch.randint(10, (nb, self.len_source))
55         source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
56         marker2 = torch.full((nb, 1), 11)
57         result = operators.bmm(source[:, :, None]).squeeze(-1)
58         sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
59         ar_mask = (sequences == 11).long()
60         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
61         return sequences, ar_mask
62
63     def seq2str(self, seq):
64         return "".join("0123456789|>"[x.item()] for x in seq)
65
66
67 class ProblemLevel2(Problem):
68     def __init__(self, len_source=5, len_result=8):
69         self.len_source = len_source
70         self.len_result = len_result
71
72     def generate_sequences(self, nb):
73         operators = F.one_hot(
74             torch.rand(nb, self.len_result, self.len_source).argmax(-1),
75             num_classes=self.len_source,
76         )
77         source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
78         marker1 = torch.full((nb, 1), 10)
79         result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
80         marker2 = torch.full((nb, 1), 11)
81         source2 = torch.randint(10, (nb, self.len_source))
82         marker3 = torch.full((nb, 1), 12)
83         result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
84
85         sequences = torch.cat(
86             (source1, marker1, result1, marker2, source2, marker3, result2), 1
87         )
88         ar_mask = (sequences == 12).long()
89         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
90         return sequences, ar_mask
91
92     def seq2str(self, seq):
93         return "".join("0123456789>|~"[x.item()] for x in seq)
94
95
96 ####################
97
98
99 class ProblemAddition(Problem):
100     def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
101         self.nb_digits = nb_digits
102         self.zero_padded = zero_padded
103         self.inverted_result = inverted_result
104         self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")])
105         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
106
107     def tensorize(self, strings):
108         len_max = max([len(x) for x in strings])
109         return torch.cat(
110             [
111                 torch.tensor(
112                     [
113                         [self.char2id[c] for c in s + "$" * (len_max - len(s))]
114                         for s in strings
115                     ]
116                 )
117             ],
118             0,
119         )
120
121     def generate_sequences(self, nb):
122         sequences = []
123         for k in range(nb):
124             a, b = torch.randint(10**self.nb_digits, (2,))
125             c = a + b
126             a, b, c = str(a.item()), str(b.item()), str(c.item())
127             if self.zero_padded:
128                 a = "0" * (self.nb_digits - len(a)) + a
129                 b = "0" * (self.nb_digits - len(b)) + b
130                 c = "0" * (self.nb_digits + 1 - len(c)) + c
131             if self.inverted_result:
132                 c = c[::-1]
133             sequences.append(f"{a}+{b}={c}$")
134
135         sequences = self.tensorize(sequences)
136         ar_mask = (sequences == self.char2id["="]).long()
137         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
138         return sequences, ar_mask
139
140     def seq2str(self, seq):
141         return "".join(self.id2char[x.item()] for x in seq)
142
143
144 # class ProblemUnion(Problem):
145 # problems = [ProblemByheart()]
146 # nb_common_codes = 100
147
148 # def generate_sequences(nb_samples):
149 # problem_indexes = torch.randint(len(problems), (nb_samples,))
150 # nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0)
151 # print(f"{nb_samples_per_problem}")
152 # all_seq = []
153 # for nb, p in zip(nb_samples_per_problem, problems):
154 # all_seq.append(p.generate_sequences(nb_samples_per_problem[nb]))
155 # return all_seq
156
157 # for strain, stest in zip(train_seq, test_seq):
158 # s = torch.cat((strain, stest), 0)
159