+def compute_perplexity(model, split="train"):
+ with torch.autograd.no_grad():
+ t = model.training
+ model.eval()
+
+ nb_samples, acc_loss = 0, 0.0
+
+ for input in task.batches(split=split):
+ input = input.to(device)
+
+ output = model(mygpt.BracketedSequence(input)).x
+ loss = F.cross_entropy(output.transpose(1, 2), input)
+ acc_loss += loss.item() * input.size(0)
+ nb_samples += input.size(0)
+
+ model.train(t)
+
+ return math.exp(min(100, acc_loss / nb_samples))
+
+
+######################################################################
+
+
+def one_shot(gpt, task):
+ pass
+
+
+######################################################################
+
+