X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=problems.py;h=dca201fdce93d37d714f99944264c436c4b8219a;hb=d2844d7a2d09ef38dc6f62d5e131059cccc872c5;hp=1f4098a831ef120cb3f97ed22a0b6ba8dc090c1a;hpb=6c8bed86221baae24a7c2aaaa41c009444efb5c9;p=picoclvr.git diff --git a/problems.py b/problems.py index 1f4098a..dca201f 100755 --- a/problems.py +++ b/problems.py @@ -22,7 +22,7 @@ class Problem: class ProblemLenId(Problem): - def __init__(self, nb_sentences=100, len_max=5): + def __init__(self, len_max=10): self.len_max = len_max def generate_sequences(self, nb): @@ -38,15 +38,14 @@ class ProblemLenId(Problem): + (k > a) * (k < b) * i[1] + (k == b) * 11 + (k > b) * (k < c) * i[1] - + (k == c) * 12 - + (k > c) * 13 + + (k >= c) * 12 ) ar_mask = (sequences == 11).long() ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) return sequences, ar_mask def seq2str(self, seq): - return "".join("0123456789|>.?"[x.item()] for x in seq) + return "".join("0123456789|>_"[x.item()] for x in seq) ####################