Update: reformat gpt-test.py in black style and print the model name with its parameter count.
[pytorch.git] / gpt-test.py
index 557f734..ddd7dcf 100755 (executable)
@@ -22,9 +22,10 @@ from transformers import GPT2Tokenizer, GPT2LMHeadModel
 
 ######################################################################
 
-def complete(model, tokenizer,
-             primer,
-             nb_sentences = 1, nb_token_max = 100, temperature = None):
+
+def complete(
+    model, tokenizer, primer, nb_sentences=1, nb_token_max=100, temperature=None
+):
     nt, ns = 0, 0
     tokens = tokenizer.encode(primer)
     primer_len = len(tokens)
@@ -33,29 +34,41 @@ def complete(model, tokenizer,
         if temperature is None:
             next_token = torch.argmax(outputs[0, -1])
         else:
-            dist =  torch.distributions.Categorical(logits = outputs[0, -1] / temperature)
+            dist = torch.distributions.Categorical(logits=outputs[0, -1] / temperature)
             next_token = dist.sample((1,)).item()
 
         tokens.append(next_token)
         nt += 1
-        if tokenizer.decode([next_token]) == '.': ns += 1
+        if tokenizer.decode([next_token]) == ".":
+            ns += 1
         if ns == nb_sentences or nt == nb_token_max:
-            return '<' + tokenizer.decode(tokens[:primer_len]) + '>' + \
-                tokenizer.decode(tokens[primer_len:])
+            return (
+                "<"
+                + tokenizer.decode(tokens[:primer_len])
+                + ">"
+                + tokenizer.decode(tokens[primer_len:])
+            )
+
 
 ######################################################################
 
-#model_name = 'gpt2'
-#model_name = 'gpt2-large'
-model_name = 'gpt2-xl'
+# model_name = 'gpt2'
+# model_name = 'gpt2-large'
+model_name = "gpt2-xl"
 
 tokenizer = GPT2Tokenizer.from_pretrained(model_name)
 model = GPT2LMHeadModel.from_pretrained(model_name)
 model.eval()
 
 print(
-    complete(model, tokenizer,
-             'The object was blue all over, but also green all over, it was a',
+    f"Using {model_name} ({int(sum(p.numel() for p in model.parameters())/(1e6))}M parameters)"
+)
+
+print(
+    complete(
+        model,
+        tokenizer,
+        "The object was blue all over, but also green all over, it was a",
     )
 )