From a4145c0493bf53f1d076f98d1ecc36cebf36478c Mon Sep 17 00:00:00 2001
From: Francois Fleuret
Date: Mon, 25 Jul 2022 21:04:30 +0200
Subject: [PATCH] OCD update

---
 mygpt.py | 23 ++++++++++++++++-------
 1 file changed, 16 insertions(+), 7 deletions(-)

diff --git a/mygpt.py b/mygpt.py
index 57cbbc6..37fe6af 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -41,34 +41,43 @@ class PositionalEncoding(nn.Module):
 ##############################
 
 class QKVAttention(nn.Module):
-    def __init__(self, dim_in, dim_qk, dim_v,
-                 nb_heads = 1, causal = False, attention_dropout = 0.0):
+    def __init__(
+        self,
+        dim_in, dim_qk, dim_v,
+        nb_heads = 1, causal = False, attention_dropout = 0.0
+    ):
         super().__init__()
 
         def randw(*d):
-            return nn.Parameter(torch.empty(*d).normal_(0, 1 / math.sqrt(d[-1])))
+            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+
+        self.causal = causal
+        self.attention_dropout = attention_dropout
 
         self.w_q = randw(nb_heads, dim_qk, dim_in)
         self.w_k = randw(nb_heads, dim_qk, dim_in)
         self.w_v = randw(nb_heads, dim_v, dim_in)
         self.w_o = randw(dim_in, dim_v * nb_heads)
-        self.causal = causal
-        self.attention_dropout = attention_dropout
 
     def forward(self, x_q, x_kv = None):
         if x_kv is None: x_kv = x_q
+
         q = torch.einsum('ntc,hdc->nhtd', x_q, self.w_q)
         k = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_k)
         v = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_v)
+
         a = torch.einsum('nhtd,nhsd->nhts', q, k) / math.sqrt(q.size(3))
+
         if self.causal:
             mask = torch.arange(a.size(2), device = q.device)[None, None, :, None] \
                    < torch.arange(a.size(3), device = q.device)[None, None, None, :]
             a = a.masked_fill(mask, float('-inf'))
+
         a = a.softmax(dim = 3)
         a = F.dropout(a, self.attention_dropout, self.training)
-        y = torch.einsum('nhts,nhsd->nthd', a, v)
-        y = y.flatten(2) @ self.w_o
+        y = torch.einsum('nhts,nhsd->nthd', a, v).flatten(2)
+
+        y = y @ self.w_o
 
         return y
 
-- 
2.20.1
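
Editor's note: below is a minimal usage sketch of the patched QKVAttention, not part of the patch itself. It assumes mygpt.py is importable and that torch is installed; the sizes are illustrative and chosen so that dim_v * nb_heads == dim_in, which the w_o projection as parameterized above requires.

    # Usage sketch (illustrative sizes only, not part of the patch).
    import torch

    from mygpt import QKVAttention  # assumes mygpt.py is on the path

    attn = QKVAttention(
        dim_in = 64, dim_qk = 16, dim_v = 16,
        nb_heads = 4, causal = True, attention_dropout = 0.0,
    )

    x = torch.randn(2, 10, 64)  # (N, T, dim_in)
    y = attn(x)                 # self-attention: x_kv defaults to x_q

    print(y.size())             # expected: torch.Size([2, 10, 64])

Since w_q, w_k, and w_v each carry a leading nb_heads dimension, the einsum contractions run all heads in one batched operation; the flatten(2) before w_o concatenates the per-head values back into a single channel dimension.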