From 38c9162209ddc1894da6805a3c7459d8c2b3a13d Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Tue, 26 Jul 2022 13:27:58 +0200 Subject: [PATCH] Added a null token, which is the one to predict. --- main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.py b/main.py index 4a332b8..e973291 100755 --- a/main.py +++ b/main.py @@ -204,7 +204,7 @@ class TaskPicoCLVR(Task): t_generated = [ ] for j in range(nb_tokens): - t = [ [ self.token2id[u] for u in t_primer + t_generated ] ] + t = [ [ self.token2id[u] for u in t_primer + t_generated ] + [ 0 ] ] input = torch.tensor(t, device = self.device) output = model(input) logits = output[0, -1] -- 2.20.1