X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=43711b3da39db0c7709b0613df0e3bd6a5f5153c;hb=0c51561334475af559cda12627388c9d5567a55f;hp=a23470b046faa4b4a6a2c0853c09c4c124a6679f;hpb=063e25c1e1442c406746a39220f3c3590882cf51;p=mygpt.git diff --git a/mygpt.py b/mygpt.py index a23470b..43711b3 100755 --- a/mygpt.py +++ b/mygpt.py @@ -41,31 +41,45 @@ class PositionalEncoding(nn.Module): ############################## class QKVAttention(nn.Module): - def __init__(self, dim_in, dim_qk, dim_v, nb_heads = 1, causal = False, attention_dropout = 0.0): + def __init__( + self, + dim_in, dim_qk, dim_v, + nb_heads = 1, causal = False, attention_dropout = 0.0 + ): super().__init__() def randw(*d): - return nn.Parameter(torch.empty(*d).normal_(0, 1 / math.sqrt(d[-1]))) + return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) - self.wq = randw(nb_heads, dim_qk, dim_in) - self.wk = randw(nb_heads, dim_qk, dim_in) - self.wv = randw(nb_heads, dim_v, dim_in) self.causal = causal self.attention_dropout = attention_dropout - def forward(self, x): - q = torch.einsum('ntc,hdc->nhtd', x, self.wq) - k = torch.einsum('ntc,hdc->nhtd', x, self.wk) - v = torch.einsum('ntc,hdc->nhtd', x, self.wv) - r = math.sqrt(q.size(3)) - a = torch.einsum('nhtd,nhsd->nhts', q, k).div(r) + self.w_q = randw(nb_heads, dim_qk, dim_in) + self.w_k = randw(nb_heads, dim_qk, dim_in) + self.w_v = randw(nb_heads, dim_v, dim_in) + self.w_o = randw(dim_v * nb_heads, dim_in) + + def forward(self, x_q, x_kv = None): + if x_kv is None: x_kv = x_q + + q = torch.einsum('ntc,hdc->nhtd', x_q, self.w_q) + k = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_k) + v = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_v) + + a = torch.einsum('nhtd,nhsd->nhts', q, k) / math.sqrt(q.size(3)) + if self.causal: - mask = torch.tril(q.new_ones(a.size(2), a.size(3)))[None, None, :, :] == 0 + mask = torch.arange(a.size(2), device = q.device)[None, None, :, None] \ + < torch.arange(a.size(3), device = q.device)[None, None, None, :] a = a.masked_fill(mask, float('-inf')) + a = a.softmax(dim = 3) a = F.dropout(a, self.attention_dropout, self.training) - y = torch.einsum('nhts,nhsd->nhtd', a, v) - return y.permute(0, 2, 1, 3).flatten(2) # nhtd -> nt(hd) + y = torch.einsum('nhts,nhsd->nthd', a, v).flatten(2) + + y = y @ self.w_o + + return y ############################## @@ -93,11 +107,11 @@ class MyGPT(nn.Module): nn.LayerNorm(dim_model), QKVAttention( dim_in = dim_model, - dim_qk = dim_keys, dim_v = dim_model // nb_heads, + dim_qk = dim_keys, + dim_v = dim_model // nb_heads, nb_heads = nb_heads, causal = True, attention_dropout = dropout ), - nn.Linear(in_features = dim_model, out_features = dim_model), ), Residual( nn.LayerNorm(dim_model), @@ -113,20 +127,23 @@ class MyGPT(nn.Module): self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size) def forward(self, x): + x = F.pad(x, (1, 0)) x = self.embedding(x) x = self.trunk(x) x = self.readout(x) - return x + return x[:, :-1] ###################################################################### if __name__ == '__main__': + print('Basic check.') + vocabulary_size = 10 x = torch.randint(vocabulary_size, (25, 100)) model = MyGPT( vocabulary_size = vocabulary_size, - dim_model = 16, dim_keys = 50, dim_hidden = 100, + dim_model = 18, dim_keys = 50, dim_hidden = 100, nb_heads = 2, nb_blocks = 3, dropout = 0.1 )