From 9c62741f73c7bbcd00bafad84cd31325b358ef1d Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Mon, 25 Jul 2022 15:31:18 +0200 Subject: [PATCH] Added the (missing) W_o --- mygpt.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/mygpt.py b/mygpt.py index 4951460..42960b1 100755 --- a/mygpt.py +++ b/mygpt.py @@ -51,6 +51,7 @@ class QKVAttention(nn.Module): self.w_q = randw(nb_heads, dim_qk, dim_in) self.w_k = randw(nb_heads, dim_qk, dim_in) self.w_v = randw(nb_heads, dim_v, dim_in) + self.w_o = randw(nb_heads, dim_in, dim_v) self.causal = causal self.attention_dropout = attention_dropout @@ -61,12 +62,15 @@ class QKVAttention(nn.Module): v = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_v) a = torch.einsum('nhtd,nhsd->nhts', q, k) / math.sqrt(q.size(3)) if self.causal: - mask = torch.tril(q.new_ones(a.size(2), a.size(3)))[None, None, :, :] == 0 + mask = torch.arange(a.size(2), device = q.device)[None, None, :, None] \ + < torch.arange(a.size(3), device = q.device)[None, None, None, :] a = a.masked_fill(mask, float('-inf')) a = a.softmax(dim = 3) a = F.dropout(a, self.attention_dropout, self.training) y = torch.einsum('nhts,nhsd->nhtd', a, v) - return y.permute(0, 2, 1, 3).flatten(2) # nhtd -> nt(hd) + y = torch.einsum('nhtd,hcd->ntc', y, self.w_o) + + return y ############################## @@ -122,6 +126,8 @@ class MyGPT(nn.Module): ###################################################################### if __name__ == '__main__': + print('Basic check.') + vocabulary_size = 10 x = torch.randint(vocabulary_size, (25, 100)) -- 2.20.1