From: Francois Fleuret Date: Wed, 27 Jul 2022 14:07:36 +0000 (+0200) Subject: OCD cosmectics X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=mygpt.git;a=commitdiff_plain;h=b4c255babaae72d6a03b4c8e8e7e25f6ab0a19a0 OCD cosmectics --- diff --git a/main.py b/main.py index 76aeebd..77b1b22 100755 --- a/main.py +++ b/main.py @@ -212,10 +212,10 @@ class TaskPicoCLVR(Task): def vocabulary_size(self): return len(self.token2id) - def generate(self, descr_primer, model, nb_tokens): + def generate(self, primer_descr, model, nb_tokens): results = autoregression( model, self.batch_size, - 1, nb_tokens, primer = descr2tensor(descr_primer), + 1, nb_tokens, primer = descr2tensor(primer_descr), device = self.device ) return ' '.join([ self.id2token[t.item()] for t in results.flatten() ]) @@ -226,7 +226,7 @@ class TaskPicoCLVR(Task): result_descr = [ ] nb_per_primer = 8 - for descr_primer in [ + for primer_descr in [ 'red above green green top blue right of red ', 'there is red there is yellow there is blue ', 'red below yellow yellow below green green below blue red right yellow left green right blue left ', @@ -234,7 +234,7 @@ class TaskPicoCLVR(Task): ]: for k in range(nb_per_primer): - result_descr.append(self.generate(descr_primer, model, nb_tokens)) + result_descr.append(self.generate(primer_descr, model, nb_tokens)) img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in result_descr ]