Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 3 Dec 2022 14:29:16 +0000 (08:29 -0600)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 3 Dec 2022 14:29:16 +0000 (08:29 -0600)
main.py

diff --git a/main.py b/main.py
index ee44ebe..b6eb6fe 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -349,11 +349,11 @@ class TaskWiki103(Task):
                      input = F.pad(input, (0, 1)) # Add the next token, the one to predict
                      output = model(input)
                      logits = output[0, -1]
-                     if args.synthesis_sampling:
+                     if args.deterministic_synthesis:
+                         t_next = logits.argmax()
+                     else:
                          dist = torch.distributions.categorical.Categorical(logits = logits)
                          t_next = dist.sample()
-                     else:
-                         t_next = logits.argmax()
                      t_generated.append(self.vocab.lookup_token(t_next))
                      if t_generated[-1] == '<nul>': break