From: Francois Fleuret Date: Tue, 26 Jul 2022 15:16:19 +0000 (+0200) Subject: Update. X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=mygpt.git;a=commitdiff_plain;h=0553fc413c4af68bc777b70a6236c622a3b5242f Update. --- diff --git a/main.py b/main.py index 7ce80a3..1cd7342 100755 --- a/main.py +++ b/main.py @@ -118,17 +118,18 @@ def autoregression( nb_samples, nb_tokens_to_generate, starting_input = None, device = torch.device('cpu') ): - first = 0 results = torch.zeros( nb_samples, nb_tokens_to_generate, dtype = torch.int64, device = device ) - if starting_input is not None: + if starting_input is None: + first = 0 + else: first = starting_input.size(1) results = torch.cat((starting_input, results), 1) - for input in results.split(self.batch_size): + for input in results.split(args.batch_size): for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'): output = model(input) logits = output[:, s]