From 73507ec9ff7677eefdafabe4123c9b05cab28f8f Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Sat, 16 Jul 2022 11:49:57 +0200 Subject: [PATCH] Update. --- main.py | 44 ++++++++++++++++++++++++++------------------ 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/main.py b/main.py index 11cf0a3..a18beb1 100755 --- a/main.py +++ b/main.py @@ -131,6 +131,9 @@ class TaskPicoCLVR(Task): many_colors = many_colors ) + self.test_descr = descr[:nb // 5] + self.train_descr = descr[nb // 5:] + descr = [ s.strip().split(' ') for s in descr ] l = max([ len(s) for s in descr ]) descr = [ s + [ '' ] * (l - len(s)) for s in descr ] @@ -159,8 +162,26 @@ class TaskPicoCLVR(Task): def vocabulary_size(self): return len(self.token2id) + def generate(self, primer, model, nb_tokens): + t_primer = primer.strip().split(' ') + t_generated = [ ] + + for j in range(nb_tokens): + t = [ [ self.token2id[u] for u in t_primer + t_generated ] ] + input = torch.tensor(t, device = self.device) + output = model(input) + logits = output[0, -1] + if args.synthesis_sampling: + dist = torch.distributions.categorical.Categorical(logits = logits) + t = dist.sample() + else: + t = logits.argmax() + t_generated.append(self.id2token[t.item()]) + + return ' '.join(t_primer + t_generated) + def produce_results(self, n_epoch, model, nb_tokens = 50): - img = [ ] + descr = [ ] nb_per_primer = 8 for primer in [ @@ -171,30 +192,17 @@ class TaskPicoCLVR(Task): ]: for k in range(nb_per_primer): - t_primer = primer.strip().split(' ') - t_generated = [ ] - - for j in range(nb_tokens): - t = [ [ self.token2id[u] for u in t_primer + t_generated ] ] - input = torch.tensor(t, device = self.device) - output = model(input) - logits = output[0, -1] - if args.synthesis_sampling: - dist = torch.distributions.categorical.Categorical(logits = logits) - t = dist.sample() - else: - t = logits.argmax() - t_generated.append(self.id2token[t.item()]) - - descr = [ ' '.join(t_primer + t_generated) ] - img += [ picoclvr.descr2img(descr) ] + descr.append(self.generate(primer, model, nb_tokens)) + img = [ picoclvr.descr2img(d) for d in descr ] img = torch.cat(img, 0) file_name = f'result_picoclvr_{n_epoch:04d}.png' torchvision.utils.save_image(img / 255., file_name, nrow = nb_per_primer, pad_value = 0.8) log_string(f'wrote {file_name}') + log_string(f'nb_misssing {picoclvr.nb_missing_properties(descr)}') + ###################################################################### class TaskWiki103(Task): -- 2.39.5