X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=ee44ebe9ed4e1416def886b44e333d15947ebd8d;hb=f08778775c6137993f45396408b1a50bf023e5be;hp=f6934b78cd497f1f6f1ed47c1460ba20005fb333;hpb=13a6ecc6e00a75ce5a95c54c11ce6f60902f57f1;p=mygpt.git diff --git a/main.py b/main.py index f6934b7..ee44ebe 100755 --- a/main.py +++ b/main.py @@ -65,8 +65,8 @@ parser.add_argument('--nb_blocks', parser.add_argument('--dropout', type = float, default = 0.1) -parser.add_argument('--synthesis_sampling', - action='store_true', default = True) +parser.add_argument('--deterministic_synthesis', + action='store_true', default = False) parser.add_argument('--no_checkpoint', action='store_true', default = False) @@ -132,11 +132,11 @@ def autoregression( for s in range(first, input.size(1)): output = model(input) logits = output[:, s] - if args.synthesis_sampling: + if args.deterministic_synthesis: + t_next = logits.argmax(1) + else: dist = torch.distributions.categorical.Categorical(logits = logits) t_next = dist.sample() - else: - t_next = logits.argmax(1) input[:, s] = t_next return results