Added the small-weight embedding initialization.
author Francois Fleuret <francois@fleuret.org>
Sun, 7 Aug 2022 19:50:15 +0000 (21:50 +0200)
committer Francois Fleuret <francois@fleuret.org>
Sun, 7 Aug 2022 19:50:15 +0000 (21:50 +0200)
mygpt.py

index 7ff1035..3bce361 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -97,6 +97,10 @@ class MyGPT(nn.Module):
             AddPositionalEncoding(len_max),
         )
 
+        # Small embedding initialization
+        with torch.no_grad():
+            self.embedding[0].weight.normal_(0, 2e-2)
+
         trunk_blocks = [ ]
 
         for _ in range(nb_blocks):