class ProblemLenId(Problem):
- def __init__(self, nb_sentences=100, len_max=5):
+ def __init__(self, len_max=10):
self.len_max = len_max
def generate_sequences(self, nb):
+ (k > a) * (k < b) * i[1]
+ (k == b) * 11
+ (k > b) * (k < c) * i[1]
- + (k == c) * 12
- + (k > c) * 13
+ + (k >= c) * 12
)
ar_mask = (sequences == 11).long()
ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
return sequences, ar_mask
def seq2str(self, seq):
- return "".join("0123456789|>.?"[x.item()] for x in seq)
+ return "".join("0123456789|>_"[x.item()] for x in seq)
####################