From 3ae0c8f3767e4285ab548e4548576a6ddf6003bb Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 3 Dec 2022 08:29:16 -0600 Subject: [PATCH] Update. --- main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index ee44ebe..b6eb6fe 100755 --- a/main.py +++ b/main.py @@ -349,11 +349,11 @@ class TaskWiki103(Task): input = F.pad(input, (0, 1)) # Add the next token, the one to predict output = model(input) logits = output[0, -1] - if args.synthesis_sampling: + if args.deterministic_synthesis: + t_next = logits.argmax() + else: dist = torch.distributions.categorical.Categorical(logits = logits) t_next = dist.sample() - else: - t_next = logits.argmax() t_generated.append(self.vocab.lookup_token(t_next)) if t_generated[-1] == '': break -- 2.20.1