From f031f7f87b5be907df081395023a9acba8ba9c7c Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Wed, 27 Jul 2022 18:52:54 +0200 Subject: [PATCH] Fixed stuff. --- main.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/main.py b/main.py index 5f3e8cf..1b011a2 100755 --- a/main.py +++ b/main.py @@ -126,7 +126,7 @@ def autoregression( results = torch.cat((primer, results), 1) for input in results.split(batch_size): - for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'): + for s in range(first, input.size(1)): output = model(input) logits = output[:, s] if args.synthesis_sampling: @@ -157,6 +157,10 @@ import picoclvr class TaskPicoCLVR(Task): def tensorize(self, descr): + descr = [ s.strip().split(' ') for s in descr ] + l = max([ len(s) for s in descr ]) + #descr = [ [ '' ] * (l - len(s)) + s for s in descr ] + descr = [ s + [ '' ] * (l - len(s)) for s in descr ] t = [ [ self.token2id[u] for u in s ] for s in descr ] return torch.tensor(t, device = self.device) @@ -165,19 +169,12 @@ class TaskPicoCLVR(Task): device = torch.device('cpu')): def generate_descr(nb): - descr = picoclvr.generate( + return picoclvr.generate( nb, height = self.height, width = self.width, nb_colors = nb_colors ) - descr = [ s.strip().split(' ') for s in descr ] - l = max([ len(s) for s in descr ]) - #descr = [ [ '' ] * (l - len(s)) + s for s in descr ] - descr = [ s + [ '' ] * (l - len(s)) for s in descr ] - - return descr - self.height = height self.width = width self.batch_size = batch_size @@ -188,10 +185,10 @@ class TaskPicoCLVR(Task): self.test_descr = generate_descr((nb * 1) // 5) # Build the tokenizer - tokens = set() + tokens = { '' } for d in [ self.train_descr, self.test_descr ]: for s in d: - for t in s: tokens.add(t) + for t in s.strip().split(' '): tokens.add(t) self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ]) self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ]) @@ -223,8 +220,8 @@ class TaskPicoCLVR(Task): for k in range(nb_per_primer): results = autoregression( model, self.batch_size, - nb_samples = 1, nb_tokens = nb_tokens, - primer = self.tensorize(primer_descr), + nb_samples = 1, nb_tokens_to_generate = nb_tokens, + primer = self.tensorize([ primer_descr ]), device = self.device ) r = ' '.join([ self.id2token[t.item()] for t in results.flatten() ]) -- 2.20.1