From 086ec8f8d2ffeaac270fbedd991bb79122db7fdf Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Tue, 26 Jul 2022 17:06:43 +0200 Subject: [PATCH] Update. --- main.py | 64 ++++++++++++++++++++++++++--------------------------- picoclvr.py | 4 ++-- 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/main.py b/main.py index a83107b..7ce80a3 100755 --- a/main.py +++ b/main.py @@ -78,8 +78,8 @@ parser.add_argument('--checkpoint_name', ############################## # picoclvr options -parser.add_argument('--picoclvr_many_colors', - action='store_true', default = False) +parser.add_argument('--picoclvr_nb_colors', + type = int, default = 5) parser.add_argument('--picoclvr_height', type = int, default = 12) @@ -113,22 +113,33 @@ for n in vars(args): ###################################################################### -def produce_results( - self, - model, nb_samples, nb_tokens_to_generate, starting_input = None, - device = 'cpu' +def autoregression( + model, + nb_samples, nb_tokens_to_generate, starting_input = None, + device = torch.device('cpu') ): - results = torch.zeros(nb_samples, nb_tokens_to_generate, dtype = torch.int64, device = device) + first = 0 + results = torch.zeros( + nb_samples, nb_tokens_to_generate, + dtype = torch.int64, device = device + ) + + if starting_input is not None: + first = starting_input.size(1) + results = torch.cat((starting_input, results), 1) + for input in results.split(self.batch_size): - for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'): + for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'): output = model(input) logits = output[:, s] if args.synthesis_sampling: dist = torch.distributions.categorical.Categorical(logits = logits) - t = dist.sample() + t_next = dist.sample() else: - t = logits.argmax(1) - input[:, s + 1] = t + t_next = logits.argmax(1) + input[:, s] = t_next + + return results ###################################################################### @@ -149,14 +160,14 @@ import picoclvr class TaskPicoCLVR(Task): def __init__(self, batch_size, - height, width, many_colors = False, + height, width, nb_colors = 5, device = torch.device('cpu')): def generate_descr(nb): descr = picoclvr.generate( nb, height = self.height, width = self.width, - many_colors = many_colors + nb_colors = nb_colors ) descr = [ s.strip().split(' ') for s in descr ] @@ -211,10 +222,10 @@ class TaskPicoCLVR(Task): logits = output[0, -1] if args.synthesis_sampling: dist = torch.distributions.categorical.Categorical(logits = logits) - t = dist.sample() + t_next = dist.sample() else: - t = logits.argmax() - t_generated.append(self.id2token[t.item()]) + t_next = logits.argmax() + t_generated.append(self.id2token[t_next.item()]) return ' '.join(t_primer + t_generated) @@ -339,10 +350,10 @@ class TaskWiki103(Task): logits = output[0, -1] if args.synthesis_sampling: dist = torch.distributions.categorical.Categorical(logits = logits) - t = dist.sample() + t_next = dist.sample() else: - t = logits.argmax() - t_generated.append(self.vocab.lookup_token(t)) + t_next = logits.argmax() + t_generated.append(self.vocab.lookup_token(t_next)) if t_generated[-1] == '': break s = ' '.join(t_generated) @@ -375,18 +386,7 @@ class TaskMNIST(Task): return 256 def produce_results(self, n_epoch, model, nb_samples = 64): - results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device) - for input in results.split(self.batch_size): - for s in tqdm.tqdm(range(input.size(1)), desc = 'synth'): - output = model(input) - logits = output[:, s] - if args.synthesis_sampling: - dist = torch.distributions.categorical.Categorical(logits = logits) - t = dist.sample() - else: - t = logits.argmax(1) - input[:, s] = t - + results = autoregression(model, nb_samples, 28 * 28, device = self.device) image_name = f'result_mnist_{n_epoch:04d}.png' torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255., image_name, nrow = 16, pad_value = 0.8) @@ -407,7 +407,7 @@ elif args.data == 'picoclvr': task = TaskPicoCLVR(batch_size = args.batch_size, height = args.picoclvr_height, width = args.picoclvr_width, - many_colors = args.picoclvr_many_colors, + nb_colors = args.picoclvr_nb_colors, device = device) else: raise ValueError(f'Unknown dataset {args.data}.') diff --git a/picoclvr.py b/picoclvr.py index 19517af..2d57505 100755 --- a/picoclvr.py +++ b/picoclvr.py @@ -95,9 +95,9 @@ def all_properties(height, width, nb_squares, square_i, square_j, square_c): def generate(nb, height, width, max_nb_squares = 5, max_nb_properties = 10, - many_colors = False): + nb_colors = 5): - nb_colors = len(color_tokens) - 1 if many_colors else max_nb_squares + assert nb_colors >= max_nb_squares and nb_colors <= len(color_tokens) - 1 descr = [ ] -- 2.20.1