From ceda7771b579aa3fb21115c6e71975d3cb7583bd Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Mon, 25 Jul 2022 18:14:57 +0200 Subject: [PATCH] Initialize properly w_o. --- mygpt.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mygpt.py b/mygpt.py index 42960b1..57cbbc6 100755 --- a/mygpt.py +++ b/mygpt.py @@ -51,7 +51,7 @@ class QKVAttention(nn.Module): self.w_q = randw(nb_heads, dim_qk, dim_in) self.w_k = randw(nb_heads, dim_qk, dim_in) self.w_v = randw(nb_heads, dim_v, dim_in) - self.w_o = randw(nb_heads, dim_in, dim_v) + self.w_o = randw(dim_in, dim_v * nb_heads) self.causal = causal self.attention_dropout = attention_dropout @@ -67,8 +67,8 @@ class QKVAttention(nn.Module): a = a.masked_fill(mask, float('-inf')) a = a.softmax(dim = 3) a = F.dropout(a, self.attention_dropout, self.training) - y = torch.einsum('nhts,nhsd->nhtd', a, v) - y = torch.einsum('nhtd,hcd->ntc', y, self.w_o) + y = torch.einsum('nhts,nhsd->nthd', a, v) + y = y.flatten(2) @ self.w_o return y -- 2.20.1