# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

# You need to install PyTorch
#
# https://pytorch.org/get-started/locally/
#
# and Huggingface's transformers (which include pre-trained GPT
# models):
#
#   pip install transformers
21 from transformers import GPT2Tokenizer, GPT2LMHeadModel
######################################################################
# NOTE(review): incomplete fragment of an autoregressive GPT-2 text
# generator. The `def` header (function name unknown from here), the
# `while` loop header, the `else:` line, the `nt`/`ns` counter updates,
# and the final `return` were lost in extraction, and every surviving
# line still carries its original line number fused at the front
# ("27 ", "30 ", ...). This cannot parse as-is; restore the missing
# lines before running. Comments below describe only the visible logic.
#
# Visible contract: takes a model, its tokenizer, a primer string, and
# stopping parameters (nb_sentences, nb_token_max); temperature=None
# selects greedy decoding, otherwise sampling.
27 model, tokenizer, primer, nb_sentences=1, nb_token_max=100, temperature=None
# Encode the primer once; its token length is kept so the return value
# can split primer from generated continuation.
30 tokens = tokenizer.encode(primer)
31 primer_len = len(tokens)
# Full forward pass over everything generated so far; only the logits of
# the last position are used to pick the next token.
33 outputs = model(torch.tensor([tokens])).logits
34 if temperature is None:
# Greedy decoding: take the most likely next token.
35 next_token = torch.argmax(outputs[0, -1])
# Otherwise sample from the temperature-scaled categorical distribution
# (an `else:` line is elided between these two branches).
37 dist = torch.distributions.Categorical(logits=outputs[0, -1] / temperature)
38 next_token = dist.sample((1,)).item()
40 tokens.append(next_token)
# A generated "." apparently ends a sentence; `ns` (sentences) and `nt`
# (tokens) are presumably incremented on elided lines — TODO confirm.
42 if tokenizer.decode([next_token]) == ".":
# Stop once the requested number of sentences or the token budget is hit.
44 if ns == nb_sentences or nt == nb_token_max:
# The returned string decodes the primer part and the generated part
# separately (surrounding concatenation is elided).
47 + tokenizer.decode(tokens[:primer_len])
49 + tokenizer.decode(tokens[primer_len:])
######################################################################
# Pick the pre-trained GPT-2 variant; 'gpt2-xl' is the largest released
# checkpoint. Swap in the commented alternative for a smaller model.
# model_name = 'gpt2-large'
model_name = "gpt2-xl"

# from_pretrained presumably downloads and caches the weights on first
# use — requires network access the first time (see transformers docs).
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
64 f"Using {model_name} ({int(sum(p.numel() for p in model.parameters())/(1e6))}M parameters)"
71 "The object was blue all over, but also green all over, it was a",
######################################################################