Update.
[pytorch.git] / gpt-test.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5 #
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 #
9 # You need to install PyTorch
10 #
11 #   https://pytorch.org/get-started/locally/
12 #
13 # and Huggingface's transformers (which include pre-trained GPT
14 # models)
15 #
16 #  pip install transformers
17 #
18
19 import torch
20
21 from transformers import GPT2Tokenizer, GPT2LMHeadModel
22
23 ######################################################################
24
25 def complete(model, tokenizer,
26              primer,
27              nb_sentences = 1, nb_token_max = 100, temperature = None):
28     nt, ns = 0, 0
29     tokens = tokenizer.encode(primer)
30     primer_len = len(tokens)
31     while True:
32         outputs = model(torch.tensor([tokens])).logits
33         if temperature is None:
34             next_token = torch.argmax(outputs[0, -1])
35         else:
36             dist =  torch.distributions.Categorical(logits = outputs[0, -1] / temperature)
37             next_token = dist.sample((1,)).item()
38
39         tokens.append(next_token)
40         nt += 1
41         if tokenizer.decode([next_token]) == '.': ns += 1
42         if ns == nb_sentences or nt == nb_token_max:
43             return '<' + tokenizer.decode(tokens[:primer_len]) + '>' + \
44                 tokenizer.decode(tokens[primer_len:])
45
46 ######################################################################
47
48 #model_name = 'gpt2'
49 #model_name = 'gpt2-large'
50 model_name = 'gpt2-xl'
51
52 tokenizer = GPT2Tokenizer.from_pretrained(model_name)
53 model = GPT2LMHeadModel.from_pretrained(model_name)
54 model.eval()
55
56 print(f'Using {model_name} ({int(sum(p.numel() for p in model.parameters())/(1e6))}M parameters)')
57
58 print(
59     complete(model, tokenizer,
60              'The object was blue all over, but also green all over, it was a',
61     )
62 )
63
64 ######################################################################