Update.
[mygpt.git] / mygpt.py
index 5370ffa..121ad80 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -126,7 +126,7 @@ class MyGPT(nn.Module):
         self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size)
 
     def forward(self, x):
-        x = torch.cat((x.new_zeros(x.size(0), 1), x), 1)
+        x = F.pad(x, (1, 0))
         x = self.embedding(x)
         x = self.trunk(x)
         x = self.readout(x)