From a1a7cb9e680378db521f2a1e2139db0e2db903de Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Tue, 26 Jul 2022 15:49:54 +0200 Subject: [PATCH] Update. --- main.py | 4 +++- mygpt.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index e973291..a83107b 100755 --- a/main.py +++ b/main.py @@ -204,8 +204,9 @@ class TaskPicoCLVR(Task): t_generated = [ ] for j in range(nb_tokens): - t = [ [ self.token2id[u] for u in t_primer + t_generated ] + [ 0 ] ] + t = [ [ self.token2id[u] for u in t_primer + t_generated ] ] input = torch.tensor(t, device = self.device) + input = F.pad(input, (0, 1)) # Add the next token, the one to predict output = model(input) logits = output[0, -1] if args.synthesis_sampling: @@ -333,6 +334,7 @@ class TaskWiki103(Task): for j in range(nb_tokens): input = self.tensorize([ t_primer + t_generated ]).to(self.device) + input = F.pad(input, (0, 1)) # Add the next token, the one to predict output = model(input) logits = output[0, -1] if args.synthesis_sampling: diff --git a/mygpt.py b/mygpt.py index 5370ffa..121ad80 100755 --- a/mygpt.py +++ b/mygpt.py @@ -126,7 +126,7 @@ class MyGPT(nn.Module): self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size) def forward(self, x): - x = torch.cat((x.new_zeros(x.size(0), 1), x), 1) + x = F.pad(x, (1, 0)) x = self.embedding(x) x = self.trunk(x) x = self.readout(x) -- 2.20.1