- sequences[:, :a] = (nb_operators[:, None] / 10 ** torch.arange(a-1,-1,-1)) % 10
- sequences[:, a] = 10
- sequences[:, a + 1 : b] = torch.randint(10, (nb, b - a - 1))
- sequences[:, b] = 11
-
- o = self.operators[nb_operators]
- p = sequences[:, a + 1 : b]
- print(f"{o.size()=} {p.size()=} {sequences[:,b+1:].size()=}")
- sequences[:, b + 1 :] = o.bmm(p[:, :, None]).squeeze(-1)
+ operators = self.operators[nb_operators]
+ nb_operators = (nb_operators[:, None] // 10 ** torch.arange(self.len_nb_operator-1,-1,-1)) % 10
+ marker1 = torch.full((nb,1),10)
+ source = torch.randint(10, (nb, self.len_source))
+ marker2 = torch.full((nb,1),11)
+ result = operators.bmm(source[:, :, None]).squeeze(-1)
+ print(f"{nb_operators.dtype=} {marker1.dtype=}")
+ sequences = torch.cat((nb_operators, marker1, source,marker2,result),1)
+ print(f"{sequences.size()=}")