X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=b579177623ba6c7bbf897b06279be6f4d6ed663a;hb=02a6cbfdda55ba2ce13ff0f925009c15cc2a0a90;hp=11cf0a31401eb0c21aff0875564f80d2e7270db3;hpb=c8d0cf6842db19f84a78c1b3a4d2666b323a5d4a;p=mygpt.git diff --git a/main.py b/main.py index 11cf0a3..b579177 100755 --- a/main.py +++ b/main.py @@ -72,9 +72,18 @@ parser.add_argument('--synthesis_sampling', parser.add_argument('--checkpoint_name', type = str, default = 'checkpoint.pth') +############################## +# picoclvr options + parser.add_argument('--picoclvr_many_colors', action='store_true', default = False) +parser.add_argument('--picoclvr_height', + type = int, default = 12) + +parser.add_argument('--picoclvr_width', + type = int, default = 16) + ###################################################################### args = parser.parse_args() @@ -118,19 +127,24 @@ import picoclvr class TaskPicoCLVR(Task): def __init__(self, batch_size, - height = 6, width = 8, many_colors = False, + height, width, many_colors = False, device = torch.device('cpu')): + self.height = height + self.width = width self.batch_size = batch_size self.device = device nb = args.data_size if args.data_size > 0 else 250000 descr = picoclvr.generate( nb, - height = height, width = width, + height = self.height, width = self.width, many_colors = many_colors ) + # self.test_descr = descr[:nb // 5] + # self.train_descr = descr[nb // 5:] + descr = [ s.strip().split(' ') for s in descr ] l = max([ len(s) for s in descr ]) descr = [ s + [ '' ] * (l - len(s)) for s in descr ] @@ -159,8 +173,28 @@ class TaskPicoCLVR(Task): def vocabulary_size(self): return len(self.token2id) - def produce_results(self, n_epoch, model, nb_tokens = 50): - img = [ ] + def generate(self, primer, model, nb_tokens): + t_primer = primer.strip().split(' ') + t_generated = [ ] + + for j in range(nb_tokens): + t = [ [ self.token2id[u] for u in t_primer + t_generated ] ] + input = torch.tensor(t, device = self.device) + output = model(input) + logits = output[0, -1] + if args.synthesis_sampling: + dist = torch.distributions.categorical.Categorical(logits = logits) + t = dist.sample() + else: + t = logits.argmax() + t_generated.append(self.id2token[t.item()]) + + return ' '.join(t_primer + t_generated) + + def produce_results(self, n_epoch, model, nb_tokens = None): + if nb_tokens is None: + nb_tokens = self.height * self.width + 3 + descr = [ ] nb_per_primer = 8 for primer in [ @@ -171,30 +205,26 @@ class TaskPicoCLVR(Task): ]: for k in range(nb_per_primer): - t_primer = primer.strip().split(' ') - t_generated = [ ] - - for j in range(nb_tokens): - t = [ [ self.token2id[u] for u in t_primer + t_generated ] ] - input = torch.tensor(t, device = self.device) - output = model(input) - logits = output[0, -1] - if args.synthesis_sampling: - dist = torch.distributions.categorical.Categorical(logits = logits) - t = dist.sample() - else: - t = logits.argmax() - t_generated.append(self.id2token[t.item()]) - - descr = [ ' '.join(t_primer + t_generated) ] - img += [ picoclvr.descr2img(descr) ] + descr.append(self.generate(primer, model, nb_tokens)) + img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ] img = torch.cat(img, 0) file_name = f'result_picoclvr_{n_epoch:04d}.png' - torchvision.utils.save_image(img / 255., - file_name, nrow = nb_per_primer, pad_value = 0.8) + torchvision.utils.save_image( + img / 255., + file_name, nrow = nb_per_primer, pad_value = 0.8 + ) log_string(f'wrote {file_name}') + nb_missing = sum( [ + x[2] for x in picoclvr.nb_missing_properties( + descr, + height = self.height, width = self.width + ) + ] ) + + log_string(f'nb_missing {nb_missing / len(descr):.02f}') + ###################################################################### class TaskWiki103(Task): @@ -356,7 +386,11 @@ if args.data == 'wiki103': elif args.data == 'mnist': task = TaskMNIST(batch_size = args.batch_size, device = device) elif args.data == 'picoclvr': - task = TaskPicoCLVR(batch_size = args.batch_size, many_colors = args.picoclvr_many_colors, device = device) + task = TaskPicoCLVR(batch_size = args.batch_size, + height = args.picoclvr_height, + width = args.picoclvr_width, + many_colors = args.picoclvr_many_colors, + device = device) else: raise ValueError(f'Unknown dataset {args.data}.')