From: Francois Fleuret Date: Mon, 25 Jul 2022 16:14:57 +0000 (+0200) Subject: Initialize properly w_o. X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=mygpt.git;a=commitdiff_plain;h=ceda7771b579aa3fb21115c6e71975d3cb7583bd Initialize properly w_o. --- diff --git a/mygpt.py b/mygpt.py index 42960b1..57cbbc6 100755 --- a/mygpt.py +++ b/mygpt.py @@ -51,7 +51,7 @@ class QKVAttention(nn.Module): self.w_q = randw(nb_heads, dim_qk, dim_in) self.w_k = randw(nb_heads, dim_qk, dim_in) self.w_v = randw(nb_heads, dim_v, dim_in) - self.w_o = randw(nb_heads, dim_in, dim_v) + self.w_o = randw(dim_in, dim_v * nb_heads) self.causal = causal self.attention_dropout = attention_dropout @@ -67,8 +67,8 @@ class QKVAttention(nn.Module): a = a.masked_fill(mask, float('-inf')) a = a.softmax(dim = 3) a = F.dropout(a, self.attention_dropout, self.training) - y = torch.einsum('nhts,nhsd->nhtd', a, v) - y = torch.einsum('nhtd,hcd->ntc', y, self.w_o) + y = torch.einsum('nhts,nhsd->nthd', a, v) + y = y.flatten(2) @ self.w_o return y