From: Francois Fleuret Date: Sat, 20 Aug 2022 05:47:14 +0000 (+0200) Subject: Replaced --synthesis_sampling with --deterministic_synthesis. X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=mygpt.git;a=commitdiff_plain;h=f08778775c6137993f45396408b1a50bf023e5be Replaced --synthesis_sampling with --deterministic_synthesis. --- diff --git a/main.py b/main.py index f6934b7..ee44ebe 100755 --- a/main.py +++ b/main.py @@ -65,8 +65,8 @@ parser.add_argument('--nb_blocks', parser.add_argument('--dropout', type = float, default = 0.1) -parser.add_argument('--synthesis_sampling', - action='store_true', default = True) +parser.add_argument('--deterministic_synthesis', + action='store_true', default = False) parser.add_argument('--no_checkpoint', action='store_true', default = False) @@ -132,11 +132,11 @@ def autoregression( for s in range(first, input.size(1)): output = model(input) logits = output[:, s] - if args.synthesis_sampling: + if args.deterministic_synthesis: + t_next = logits.argmax(1) + else: dist = torch.distributions.categorical.Categorical(logits = logits) t_next = dist.sample() - else: - t_next = logits.argmax(1) input[:, s] = t_next return results