Initialize w_o properly.
author     Francois Fleuret <francois@fleuret.org>
           Mon, 25 Jul 2022 16:14:57 +0000 (18:14 +0200)
committer  Francois Fleuret <francois@fleuret.org>
           Mon, 25 Jul 2022 16:14:57 +0000 (18:14 +0200)
mygpt.py

index 42960b1..57cbbc6 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -51,7 +51,7 @@ class QKVAttention(nn.Module):
         self.w_q = randw(nb_heads, dim_qk, dim_in)
         self.w_k = randw(nb_heads, dim_qk, dim_in)
         self.w_v = randw(nb_heads, dim_v, dim_in)
-        self.w_o = randw(nb_heads, dim_in, dim_v)
+        self.w_o = randw(dim_in, dim_v * nb_heads)
         self.causal = causal
         self.attention_dropout = attention_dropout
 
@@ -67,8 +67,8 @@ class QKVAttention(nn.Module):
             a = a.masked_fill(mask, float('-inf'))
         a = a.softmax(dim = 3)
         a = F.dropout(a, self.attention_dropout, self.training)
-        y = torch.einsum('nhts,nhsd->nhtd', a, v)
-        y = torch.einsum('nhtd,hcd->ntc', y, self.w_o)
+        y = torch.einsum('nhts,nhsd->nthd', a, v)
+        y = y.flatten(2) @ self.w_o
 
         return y
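
The reshape changes what randw() treats as the fan-in of the output projection: with shape (nb_heads, dim_in, dim_v) the entries are scaled by 1/sqrt(dim_v), while the projection actually sums over nb_heads * dim_v inputs. Below is a minimal sketch, not part of the commit, illustrating the effect. It assumes randw() scales its entries by 1/sqrt of the last dimension, as the helper does elsewhere in mygpt.py, and uses hypothetical sizes with dim_in == nb_heads * dim_v, as in the default configuration.

    # Sketch only: compares the old and new w_o parameterizations at init.
    import math
    import torch
    from torch import nn

    def randw(*d):
        # assumed helper mirroring mygpt.py: entries have std 1/sqrt(d[-1])
        return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

    nb_heads, dim_in, dim_v = 4, 64, 16   # hypothetical sizes, dim_in == nb_heads * dim_v
    N, T = 2, 100

    w_o_old = randw(nb_heads, dim_in, dim_v)    # entries scaled by 1/sqrt(dim_v)
    w_o_new = randw(dim_in, dim_v * nb_heads)   # entries scaled by 1/sqrt(dim_v * nb_heads)

    y = torch.randn(N, nb_heads, T, dim_v)      # per-head values after the attention product

    # Old path: contract heads and value dims in a single einsum.
    out_old = torch.einsum('nhtd,hcd->ntc', y, w_o_old)

    # New path: move heads next to the value dim, flatten, one matmul.
    out_new = y.permute(0, 2, 1, 3).flatten(2) @ w_o_new

    # The projection sums over nb_heads * dim_v terms, so that is the true fan-in.
    print(out_old.std())  # ~ sqrt(nb_heads), i.e. ~2 here: variance too large at init
    print(out_new.std())  # ~ 1: unit variance at initialization

Only the new shape exposes the full fan-in dim_v * nb_heads to randw(), which is what the commit message refers to; the einsum/flatten change in the second hunk is the matching rewrite of the forward pass for the reshaped parameter.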