From: François Fleuret Date: Sat, 3 Dec 2022 14:29:16 +0000 (-0600) Subject: Update. X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=mygpt.git;a=commitdiff_plain;h=3ae0c8f3767e4285ab548e4548576a6ddf6003bb Update. --- diff --git a/main.py b/main.py index ee44ebe..b6eb6fe 100755 --- a/main.py +++ b/main.py @@ -349,11 +349,11 @@ class TaskWiki103(Task): input = F.pad(input, (0, 1)) # Add the next token, the one to predict output = model(input) logits = output[0, -1] - if args.synthesis_sampling: + if args.deterministic_synthesis: + t_next = logits.argmax() + else: dist = torch.distributions.categorical.Categorical(logits = logits) t_next = dist.sample() - else: - t_next = logits.argmax() t_generated.append(self.vocab.lookup_token(t_next)) if t_generated[-1] == '': break