Fixed the size of w_o.
author: Francois Fleuret <francois@fleuret.org>
Tue, 26 Jul 2022 15:06:13 +0000 (17:06 +0200)
committer: Francois Fleuret <francois@fleuret.org>
Tue, 26 Jul 2022 15:06:13 +0000 (17:06 +0200)
mygpt.py

index 121ad80..ab16e1e 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -57,7 +57,7 @@ class QKVAttention(nn.Module):
         self.w_q = randw(nb_heads, dim_qk, dim_in)
         self.w_k = randw(nb_heads, dim_qk, dim_in)
         self.w_v = randw(nb_heads, dim_v, dim_in)
-        self.w_o = randw(dim_in, dim_v * nb_heads)
+        self.w_o = randw(dim_v * nb_heads, dim_in)
 
     def forward(self, x_q, x_kv = None):
         if x_kv is None: x_kv = x_q
@@ -142,7 +142,7 @@ if __name__ == '__main__':
 
     model = MyGPT(
         vocabulary_size = vocabulary_size,
-        dim_model = 16, dim_keys = 50, dim_hidden = 100,
+        dim_model = 18, dim_keys = 50, dim_hidden = 100,
         nb_heads = 2, nb_blocks = 3,
         dropout = 0.1
     )