Added the small-weight embedding initialization.
[mygpt.git] / mygpt.py
index 7ff1035..3bce361 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -97,6 +97,10 @@ class MyGPT(nn.Module):
             AddPositionalEncoding(len_max),
         )
 
+        # Re-draw the embedding weights in place from a normal (mean 0, std 2e-2)
+        with torch.no_grad():
+            self.embedding[0].weight.normal_(0, 2e-2)
+
         trunk_blocks = [ ]
 
         for _ in range(nb_blocks):