Removed the Linear transformation since there is now w_o.
authorFrancois Fleuret <francois@fleuret.org>
Tue, 26 Jul 2022 10:47:26 +0000 (12:47 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Tue, 26 Jul 2022 10:47:26 +0000 (12:47 +0200)
mygpt.py

index 7f0c9e6..5370ffa 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -111,7 +111,6 @@ class MyGPT(nn.Module):
                         nb_heads = nb_heads,
                         causal = True, attention_dropout = dropout
                     ),
-                    nn.Linear(in_features = dim_model, out_features = dim_model),
                 ),
                 Residual(
                     nn.LayerNorm(dim_model),