Initial commit.
[pytorch.git] / gpt-test.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5 #
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 #
9 # You need to install PyTorch
10 #
11 #   https://pytorch.org/get-started/locally/
12 #
13 # and Huggingface's transformers (which include pre-trained GPT
14 # models)
15 #
16 #  pip install transformers
17 #
18
19 import torch
20
21 from transformers import GPT2Tokenizer, GPT2LMHeadModel
22
23 ######################################################################
24
25 def complete(model, primer, nb_sentences = 1, nb_token_max = 100, temperature = None):
26     nt, ns = 0, 0
27     tokens = tokenizer.encode(primer)
28     primer_len = len(tokens)
29     while True:
30         outputs = model(torch.tensor([tokens])).logits
31         if temperature is None:
32             next_token = torch.argmax(outputs[0, -1])
33         else:
34             dist =  torch.distributions.Categorical(logits = outputs[0, -1] / temperature)
35             next_token = dist.sample((1,)).item()
36
37         tokens.append(next_token)
38         nt += 1
39         if tokenizer.decode([next_token]) == '.': ns += 1
40         if ns == nb_sentences or nt == nb_token_max:
41             return '<' + tokenizer.decode(tokens[:primer_len]) + '>' + \
42                 tokenizer.decode(tokens[primer_len:])
43
44 ######################################################################
45
46 #model_name = 'gpt2'
47 #model_name = 'gpt2-large'
48 model_name = 'gpt2-xl'
49
50 tokenizer = GPT2Tokenizer.from_pretrained(model_name)
51 model = GPT2LMHeadModel.from_pretrained(model_name)
52 model.eval()
53
54 print(
55     complete(model,
56              'The object was blue all over, but also green all over, it was a',
57     )
58 )
59
60 ######################################################################