Update.
[pytorch.git] / gpt-test.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5 #
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 #
9 # You need to install PyTorch
10 #
11 #   https://pytorch.org/get-started/locally/
12 #
# and Hugging Face's transformers library (which includes pre-trained
# GPT models)
15 #
16 #  pip install transformers
17 #
18
19 import torch
20
21 from transformers import GPT2Tokenizer, GPT2LMHeadModel
22
23 ######################################################################
24
25
def complete(
    model, tokenizer, primer, nb_sentences=1, nb_token_max=100, temperature=None
):
    """Extend ``primer`` with tokens generated one at a time by ``model``.

    Each step feeds the full token sequence to ``model`` and picks the next
    token from the logits at the last position.

    - Greedy mode (``temperature`` is ``None`` or ``0``) takes the argmax.
    - Otherwise the token is sampled from ``Categorical(logits / temperature)``.

    Generation stops once ``nb_sentences`` generated tokens decode to exactly
    ``"."``, or once ``nb_token_max`` tokens have been generated, whichever
    comes first.

    Args:
        model: callable taking a ``(1, T)`` integer tensor of token ids and
            returning an object whose ``.logits`` is indexed as
            ``[batch, position, vocab]`` (e.g. ``GPT2LMHeadModel``).
        tokenizer: object providing ``encode(str) -> list[int]`` and
            ``decode(list[int]) -> str``.
        primer: text to continue.
        nb_sentences: number of "." tokens after which generation stops.
        nb_token_max: maximum number of generated tokens; a value <= 0
            generates nothing.
        temperature: ``None`` or ``0`` for greedy decoding, otherwise the
            sampling temperature.

    Returns:
        ``"<" + decoded primer + ">" + decoded continuation``.
    """
    tokens = tokenizer.encode(primer)
    primer_len = len(tokens)
    nb_generated, nb_periods = 0, 0

    # Pure inference: without no_grad every step would build an autograd
    # graph over the whole model for no reason.
    with torch.no_grad():
        # Checking the limits before generating makes nb_token_max <= 0 and
        # nb_sentences <= 0 terminate immediately instead of overshooting.
        while nb_periods < nb_sentences and nb_generated < nb_token_max:
            # The whole sequence is re-encoded each step (no key/value cache),
            # so total cost grows quadratically with the generated length.
            logits = model(torch.tensor([tokens])).logits[0, -1]
            if not temperature:
                # .item() keeps `tokens` a plain list of Python ints, matching
                # the sampling branch below.
                next_token = torch.argmax(logits).item()
            else:
                dist = torch.distributions.Categorical(logits=logits / temperature)
                next_token = dist.sample((1,)).item()

            tokens.append(next_token)
            nb_generated += 1
            # Only a token that decodes to exactly "." counts as a sentence end.
            if tokenizer.decode([next_token]) == ".":
                nb_periods += 1

    return (
        "<"
        + tokenizer.decode(tokens[:primer_len])
        + ">"
        + tokenizer.decode(tokens[primer_len:])
    )
51
52
######################################################################

# Alternative checkpoints: "gpt2", "gpt2-large".
model_name = "gpt2-xl"

tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
# Switch the module to evaluation mode before generating text.
model.eval()

# Total number of scalar weights, reported in millions (truncated).
nb_parameters = sum(param.numel() for param in model.parameters())
print(f"Using {model_name} ({int(nb_parameters / 1e6)}M parameters)")

primer = "The object was blue all over, but also green all over, it was a"
print(complete(model, tokenizer, primer))

######################################################################