Added the RNG state to the checkpoint.
[mygpt.git] / mygpt.py
index 954f4f0..3bce361 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -97,6 +97,10 @@ class MyGPT(nn.Module):
             AddPositionalEncoding(len_max),
         )
 
+        # Overwrite the first embedding module's weights with small normal values (mean 0, std 2e-2)
+        with torch.no_grad():
+            self.embedding[0].weight.normal_(0, 2e-2)
+
         trunk_blocks = [ ]
 
         for _ in range(nb_blocks):
@@ -125,11 +129,10 @@ class MyGPT(nn.Module):
         self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size)
 
     def forward(self, x):
-        x = F.pad(x, (1, 0))
+        x = F.pad(x, (1, -1))
         x = self.embedding(x)
         x = self.trunk(x)
         x = self.readout(x)
-        x = F.pad(x, (0, 0, 0, -1))
         return x
 
 ######################################################################