X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=gpt-test.py;h=096704334c2e89e4af94c28afefccf757fc91f3d;hb=7cf4aeca5c3e8feb279752d98a5ab7568bd0602f;hp=ff72e5076c7695584420bdd0a4cbac952daa8a9f;hpb=a2bf298d8b609810bbb4caaabf6633deee768481;p=pytorch.git

diff --git a/gpt-test.py b/gpt-test.py
index ff72e50..0967043 100755
--- a/gpt-test.py
+++ b/gpt-test.py
@@ -22,7 +22,9 @@ from transformers import GPT2Tokenizer, GPT2LMHeadModel
 
 ######################################################################
 
-def complete(model, primer, nb_sentences = 1, nb_token_max = 100, temperature = None):
+def complete(model, tokenizer,
+             primer,
+             nb_sentences = 1, nb_token_max = 100, temperature = None):
     nt, ns = 0, 0
     tokens = tokenizer.encode(primer)
     primer_len = len(tokens)
@@ -51,8 +53,10 @@ tokenizer = GPT2Tokenizer.from_pretrained(model_name)
 model = GPT2LMHeadModel.from_pretrained(model_name)
 model.eval()
 
+print(f'Using {model_name} ({int(sum(p.numel() for p in model.parameters())/(1e6))}M parameters)')
+
 print(
-    complete(model,
+    complete(model, tokenizer,
 'The object was blue all over, but also green all over, it was a',
     )
 )
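
Note: the hunks above only show the new signature of complete(), the first three lines of its body, the added parameter-count print, and the updated call site. As a rough sketch of how the refactored script might fit together, the fragment below reconstructs a plausible whole; the generation loop (greedy decoding with optional temperature sampling, counting '.' tokens as sentence ends) and the value of model_name are assumptions for illustration, not the author's actual code.

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

######################################################################

def complete(model, tokenizer,
             primer,
             nb_sentences = 1, nb_token_max = 100, temperature = None):
    nt, ns = 0, 0
    tokens = tokenizer.encode(primer)
    primer_len = len(tokens)
    while True:
        # The loop body is not visible in the diff; this is a guessed
        # reconstruction: predict the next token, either greedily or by
        # sampling from the temperature-scaled distribution.
        with torch.no_grad():
            logits = model(torch.tensor([tokens])).logits[0, -1]
        if temperature is None:
            next_token = torch.argmax(logits).item()
        else:
            dist = torch.distributions.Categorical(logits = logits / temperature)
            next_token = dist.sample().item()
        tokens.append(next_token)
        nt += 1
        # Treat a token whose decoding ends with a period as a sentence end.
        if tokenizer.decode([next_token]).strip().endswith('.'): ns += 1
        if ns >= nb_sentences or nt >= nb_token_max:
            return tokenizer.decode(tokens)

######################################################################

model_name = 'gpt2'  # placeholder: the actual value is not visible in this diff

tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
model.eval()

print(f'Using {model_name} ({int(sum(p.numel() for p in model.parameters())/(1e6))}M parameters)')

print(
    complete(model, tokenizer,
             'The object was blue all over, but also green all over, it was a',
    )
)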