X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=427a83a05c6aef7f55413a7a1f225fb11efbb451;hb=6260a3593ac09a9bbdd9c85b23d78a71fa028acd;hp=1b011a2f4a355b4f7ec3f43215739b740c0da43a;hpb=f031f7f87b5be907df081395023a9acba8ba9c7c;p=mygpt.git diff --git a/main.py b/main.py index 1b011a2..427a83a 100755 --- a/main.py +++ b/main.py @@ -156,13 +156,14 @@ import picoclvr class TaskPicoCLVR(Task): + # Make a tensor from a list of strings def tensorize(self, descr): - descr = [ s.strip().split(' ') for s in descr ] - l = max([ len(s) for s in descr ]) - #descr = [ [ '' ] * (l - len(s)) + s for s in descr ] - descr = [ s + [ '' ] * (l - len(s)) for s in descr ] - t = [ [ self.token2id[u] for u in s ] for s in descr ] - return torch.tensor(t, device = self.device) + token_descr = [ s.strip().split(' ') for s in descr ] + l = max([ len(s) for s in token_descr ]) + #token_descr = [ [ '' ] * (l - len(s)) + s for s in token_descr ] + token_descr = [ s + [ '' ] * (l - len(s)) for s in token_descr ] + id_descr = [ [ self.token2id[u] for u in s ] for s in token_descr ] + return torch.tensor(id_descr, device = self.device) def __init__(self, batch_size, height, width, nb_colors = 5, @@ -281,6 +282,7 @@ class TaskWiki103(Task): self.vocab.set_default_index(self.vocab[ '' ]) + # makes a tensor from a list of list of tokens def tensorize(self, s): a = max(len(x) for x in s) return torch.tensor([ self.vocab(x + [ '' ] * (a - len(x))) for x in s ])