projects
/
mygpt.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
13a6ecc
)
Replaced --synthesis_sampling with --deterministic_synthesis.
author
Francois Fleuret
<francois@fleuret.org>
Sat, 20 Aug 2022 05:47:14 +0000
(07:47 +0200)
committer
Francois Fleuret
<francois@fleuret.org>
Sat, 20 Aug 2022 05:47:14 +0000
(07:47 +0200)
main.py
patch
|
blob
|
history
diff --git
a/main.py
b/main.py
index
f6934b7
..
ee44ebe
100755
(executable)
--- a/
main.py
+++ b/
main.py
@@
-65,8
+65,8
@@
parser.add_argument('--nb_blocks',
parser.add_argument('--dropout',
type = float, default = 0.1)
parser.add_argument('--dropout',
type = float, default = 0.1)
-parser.add_argument('--
synthesis_sampling
',
- action='store_true', default =
Tru
e)
+parser.add_argument('--
deterministic_synthesis
',
+ action='store_true', default =
Fals
e)
parser.add_argument('--no_checkpoint',
action='store_true', default = False)
parser.add_argument('--no_checkpoint',
action='store_true', default = False)
@@
-132,11
+132,11
@@
def autoregression(
for s in range(first, input.size(1)):
output = model(input)
logits = output[:, s]
for s in range(first, input.size(1)):
output = model(input)
logits = output[:, s]
- if args.synthesis_sampling:
+ if args.deterministic_synthesis:
+ t_next = logits.argmax(1)
+ else:
dist = torch.distributions.categorical.Categorical(logits = logits)
t_next = dist.sample()
dist = torch.distributions.categorical.Categorical(logits = logits)
t_next = dist.sample()
- else:
- t_next = logits.argmax(1)
input[:, s] = t_next
return results
input[:, s] = t_next
return results