projects
/
pytorch.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Initial commit.
[pytorch.git]
/
gpt-test.py
diff --git
a/gpt-test.py
b/gpt-test.py
index
ff72e50
..
0967043
100755
(executable)
--- a/
gpt-test.py
+++ b/
gpt-test.py
@@
-22,7
+22,9
@@
from transformers import GPT2Tokenizer, GPT2LMHeadModel
######################################################################
######################################################################
-def complete(model, primer, nb_sentences = 1, nb_token_max = 100, temperature = None):
+def complete(model, tokenizer,
+ primer,
+ nb_sentences = 1, nb_token_max = 100, temperature = None):
nt, ns = 0, 0
tokens = tokenizer.encode(primer)
primer_len = len(tokens)
nt, ns = 0, 0
tokens = tokenizer.encode(primer)
primer_len = len(tokens)
@@
-51,8
+53,10
@@
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
model.eval()
model = GPT2LMHeadModel.from_pretrained(model_name)
model.eval()
+print(f'Using {model_name} ({int(sum(p.numel() for p in model.parameters())/(1e6))}M parameters)')
+
print(
print(
- complete(model,
+ complete(model,
tokenizer,
'The object was blue all over, but also green all over, it was a',
)
)
'The object was blue all over, but also green all over, it was a',
)
)