From: Francois Fleuret Date: Tue, 26 Jul 2022 15:06:13 +0000 (+0200) Subject: Fixed the size of w_o. X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=mygpt.git;a=commitdiff_plain;h=12184f604b37f36f07d7dcdd567b1c78f02c74db Fixed the size of w_o. --- diff --git a/mygpt.py b/mygpt.py index 121ad80..ab16e1e 100755 --- a/mygpt.py +++ b/mygpt.py @@ -57,7 +57,7 @@ class QKVAttention(nn.Module): self.w_q = randw(nb_heads, dim_qk, dim_in) self.w_k = randw(nb_heads, dim_qk, dim_in) self.w_v = randw(nb_heads, dim_v, dim_in) - self.w_o = randw(dim_in, dim_v * nb_heads) + self.w_o = randw(dim_v * nb_heads, dim_in) def forward(self, x_q, x_kv = None): if x_kv is None: x_kv = x_q @@ -142,7 +142,7 @@ if __name__ == '__main__': model = MyGPT( vocabulary_size = vocabulary_size, - dim_model = 16, dim_keys = 50, dim_hidden = 100, + dim_model = 18, dim_keys = 50, dim_hidden = 100, nb_heads = 2, nb_blocks = 3, dropout = 0.1 )