Replaced --synthesis_sampling with --deterministic_synthesis.
authorFrancois Fleuret <francois@fleuret.org>
Sat, 20 Aug 2022 05:47:14 +0000 (07:47 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Sat, 20 Aug 2022 05:47:14 +0000 (07:47 +0200)
main.py

diff --git a/main.py b/main.py
index f6934b7..ee44ebe 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -65,8 +65,8 @@ parser.add_argument('--nb_blocks',
 parser.add_argument('--dropout',
                     type = float, default = 0.1)
 
-parser.add_argument('--synthesis_sampling',
-                    action='store_true', default = True)
+parser.add_argument('--deterministic_synthesis',
+                    action='store_true', default = False)
 
 parser.add_argument('--no_checkpoint',
                     action='store_true', default = False)
@@ -132,11 +132,11 @@ def autoregression(
         for s in range(first, input.size(1)):
             output = model(input)
             logits = output[:, s]
-            if args.synthesis_sampling:
+            if args.deterministic_synthesis:
+                t_next = logits.argmax(1)
+            else:
                 dist = torch.distributions.categorical.Categorical(logits = logits)
                 t_next = dist.sample()
-            else:
-                t_next = logits.argmax(1)
             input[:, s] = t_next
 
     return results