From 98d2184fb3f202d0f513380ca00d080b64cf5e90 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Mon, 25 Jul 2022 15:31:58 +0200 Subject: [PATCH] Update. --- main.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/main.py b/main.py index 3bf6b52..ace376d 100755 --- a/main.py +++ b/main.py @@ -72,9 +72,18 @@ parser.add_argument('--synthesis_sampling', parser.add_argument('--checkpoint_name', type = str, default = 'checkpoint.pth') +############################## +# picoclvr options + parser.add_argument('--picoclvr_many_colors', action='store_true', default = False) +parser.add_argument('--picoclvr_height', + type = int, default = 12) + +parser.add_argument('--picoclvr_width', + type = int, default = 16) + ###################################################################### args = parser.parse_args() @@ -118,16 +127,18 @@ import picoclvr class TaskPicoCLVR(Task): def __init__(self, batch_size, - height = 6, width = 8, many_colors = False, + height, width, many_colors = False, device = torch.device('cpu')): + self.height = height + self.width = width self.batch_size = batch_size self.device = device nb = args.data_size if args.data_size > 0 else 250000 descr = picoclvr.generate( nb, - height = height, width = width, + height = self.height, width = self.width, many_colors = many_colors ) @@ -180,7 +191,9 @@ class TaskPicoCLVR(Task): return ' '.join(t_primer + t_generated) - def produce_results(self, n_epoch, model, nb_tokens = 50): + def produce_results(self, n_epoch, model, nb_tokens = None): + if nb_tokens is None: + nb_tokens = self.height * self.width + 3 descr = [ ] nb_per_primer = 8 @@ -194,14 +207,14 @@ class TaskPicoCLVR(Task): for k in range(nb_per_primer): descr.append(self.generate(primer, model, nb_tokens)) - img = [ picoclvr.descr2img(d) for d in descr ] + img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ] img = torch.cat(img, 0) file_name = f'result_picoclvr_{n_epoch:04d}.png' torchvision.utils.save_image(img / 255., file_name, nrow = nb_per_primer, pad_value = 0.8) log_string(f'wrote {file_name}') - nb_missing = sum( [ x[2] for x in picoclvr.nb_missing_properties(descr) ] ) + nb_missing = sum( [ x[2] for x in picoclvr.nb_missing_properties(descr, height = self.height, width = self.width) ] ) log_string(f'nb_missing {nb_missing / len(descr):.02f}') ###################################################################### @@ -365,7 +378,11 @@ if args.data == 'wiki103': elif args.data == 'mnist': task = TaskMNIST(batch_size = args.batch_size, device = device) elif args.data == 'picoclvr': - task = TaskPicoCLVR(batch_size = args.batch_size, many_colors = args.picoclvr_many_colors, device = device) + task = TaskPicoCLVR(batch_size = args.batch_size, + height = args.picoclvr_height, + width = args.picoclvr_width, + many_colors = args.picoclvr_many_colors, + device = device) else: raise ValueError(f'Unknown dataset {args.data}.') -- 2.20.1