Update: use F.pad instead of manual zero-token concatenation in main.py and mygpt.py.
author    Francois Fleuret <francois@fleuret.org>
Tue, 26 Jul 2022 13:49:54 +0000 (15:49 +0200)
committer Francois Fleuret <francois@fleuret.org>
Tue, 26 Jul 2022 13:49:54 +0000 (15:49 +0200)
main.py
mygpt.py

diff --git a/main.py b/main.py
index e973291..a83107b 100755
--- a/main.py
+++ b/main.py
@@ -204,8 +204,9 @@ class TaskPicoCLVR(Task):
         t_generated = [ ]
 
         for j in range(nb_tokens):
-            t = [ [ self.token2id[u] for u in t_primer + t_generated ] + [ 0 ] ]
+            t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
             input = torch.tensor(t, device = self.device)
+            input = F.pad(input, (0, 1)) # Add the next token, the one to predict
             output = model(input)
             logits = output[0, -1]
             if args.synthesis_sampling:
@@ -333,6 +334,7 @@ class TaskWiki103(Task):
                  for j in range(nb_tokens):
 
                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
+                     input = F.pad(input, (0, 1)) # Add the next token, the one to predict
                      output = model(input)
                      logits = output[0, -1]
                      if args.synthesis_sampling:
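
Both main.py hunks make the same change: instead of appending a literal 0 to the Python token list before tensorizing, the tensor itself is right-padded with F.pad(input, (0, 1)). The appended zero column is a placeholder slot, and the logits read at output[0, -1] give the distribution over the next token. A minimal, self-contained sketch of that loop, assuming torch.nn.functional is available as F in main.py (the import is not part of this diff) and substituting a dummy model for the real MyGPT:

import torch
import torch.nn.functional as F

# Dummy stand-in for the trained model: maps (B, T) token ids to (B, T, V) logits.
vocabulary_size = 16
model = torch.nn.Sequential(
    torch.nn.Embedding(vocabulary_size, 32),
    torch.nn.Linear(32, vocabulary_size),
)

t_primer, t_generated = [3, 5, 7], []

for _ in range(4):
    input = torch.tensor([t_primer + t_generated], dtype=torch.int64)
    input = F.pad(input, (0, 1))  # add the placeholder slot for the token to predict
    output = model(input)
    logits = output[0, -1]
    # Greedy decoding for brevity; the original samples when args.synthesis_sampling is set.
    t_generated.append(logits.argmax().item())

print(t_generated)
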
diff --git a/mygpt.py b/mygpt.py
index 5370ffa..121ad80 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -126,7 +126,7 @@ class MyGPT(nn.Module):
         self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size)
 
     def forward(self, x):
-        x = torch.cat((x.new_zeros(x.size(0), 1), x), 1)
+        x = F.pad(x, (1, 0))
         x = self.embedding(x)
         x = self.trunk(x)
         x = self.readout(x)
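
The mygpt.py change is behavior-preserving: for a batch x of token ids with shape (B, T), F.pad(x, (1, 0)) prepends one zero column (constant padding with value 0 is the default), which is exactly the tensor the removed torch.cat expression built as the start-of-sequence marker. A quick equivalence check, assuming torch.nn.functional is imported as F:

import torch
import torch.nn.functional as F

x = torch.randint(0, 16, (2, 5))  # a batch of token ids, shape (B, T)

old = torch.cat((x.new_zeros(x.size(0), 1), x), 1)  # removed line
new = F.pad(x, (1, 0))                              # added line

assert torch.equal(old, new)
print(new.size())  # torch.Size([2, 6]): one zero column prepended

Beyond brevity, F.pad preserves the input's dtype and device, just as x.new_zeros did, so the replacement is safe for the integer token tensors both call sites pass in.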