From 0553fc413c4af68bc777b70a6236c622a3b5242f Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Tue, 26 Jul 2022 17:16:19 +0200 Subject: [PATCH] Update. --- main.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 7ce80a3..1cd7342 100755 --- a/main.py +++ b/main.py @@ -118,17 +118,18 @@ def autoregression( nb_samples, nb_tokens_to_generate, starting_input = None, device = torch.device('cpu') ): - first = 0 results = torch.zeros( nb_samples, nb_tokens_to_generate, dtype = torch.int64, device = device ) - if starting_input is not None: + if starting_input is None: + first = 0 + else: first = starting_input.size(1) results = torch.cat((starting_input, results), 1) - for input in results.split(self.batch_size): + for input in results.split(args.batch_size): for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'): output = model(input) logits = output[:, s] -- 2.20.1