# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

# You need to install PyTorch
#
#  https://pytorch.org/get-started/locally/
#
# and Huggingface's transformers (which includes pre-trained GPT models)
#
#  pip install transformers
import torch

from transformers import GPT2Tokenizer, GPT2LMHeadModel
######################################################################
def complete(model, tokenizer,
             primer,
             nb_sentences = 1, nb_token_max = 100, temperature = None):
    """Auto-regressively extend the string `primer` with tokens from `model`.

    Generation stops as soon as `nb_sentences` period tokens ('.') have
    been produced, or after `nb_token_max` new tokens, whichever comes
    first.  When `temperature` is None the most likely token is picked
    greedily; otherwise tokens are sampled from the categorical
    distribution over logits / temperature.

    Returns the primer wrapped in '<...>' followed by the generated text.
    """
    nt, ns = 0, 0                     # tokens generated / sentences finished
    tokens = tokenizer.encode(primer)
    primer_len = len(tokens)
    while True:
        # Generation needs no gradients; no_grad keeps memory flat.
        with torch.no_grad():
            outputs = model(torch.tensor([tokens])).logits
        if temperature is None:
            # Greedy decoding.  .item() converts the 0-dim tensor to a
            # plain int so `tokens` stays a homogeneous list of ints
            # (appending the tensor itself would make the next
            # torch.tensor([tokens]) call unreliable).
            next_token = torch.argmax(outputs[0, -1]).item()
        else:
            dist = torch.distributions.Categorical(logits = outputs[0, -1] / temperature)
            next_token = dist.sample((1,)).item()

        tokens.append(next_token)
        nt += 1
        if tokenizer.decode([next_token]) == '.': ns += 1
        if ns == nb_sentences or nt == nb_token_max:
            return '<' + tokenizer.decode(tokens[:primer_len]) + '>' + \
                tokenizer.decode(tokens[primer_len:])
######################################################################
#model_name = 'gpt2-large'
model_name = 'gpt2-xl'

tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# Disable dropout etc. -- otherwise greedy generation is still stochastic.
model.eval()

print(f'Using {model_name} ({int(sum(p.numel() for p in model.parameters())/(1e6))}M parameters)')

# Generate a completion of the primer and show it (without print() the
# returned string would be silently discarded).
print(
    complete(model, tokenizer,
             'The object was blue all over, but also green all over, it was a')
)
######################################################################