Update.
authorFrancois Fleuret <francois@fleuret.org>
Tue, 26 Jul 2022 15:16:19 +0000 (17:16 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Tue, 26 Jul 2022 15:16:19 +0000 (17:16 +0200)
main.py

diff --git a/main.py b/main.py
index 7ce80a3..1cd7342 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -118,17 +118,18 @@ def autoregression(
         nb_samples, nb_tokens_to_generate, starting_input = None,
         device = torch.device('cpu')
 ):
-    first = 0
     results = torch.zeros(
         nb_samples, nb_tokens_to_generate,
         dtype = torch.int64, device = device
     )
 
-    if starting_input is not None:
+    if starting_input is None:
+        first = 0
+    else:
         first = starting_input.size(1)
         results = torch.cat((starting_input, results), 1)
 
-    for input in results.split(self.batch_size):
+    for input in results.split(args.batch_size):
         for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'):
             output = model(input)
             logits = output[:, s]