X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=7ff10358e77cce589ca9d1d53a5a5682ebb2e451;hb=c0019b5af155be6a8af02bf71a62c43af1d7a178;hp=57cbbc6b23d66404e374851dedf4a35fa4c85852;hpb=ceda7771b579aa3fb21115c6e71975d3cb7583bd;p=mygpt.git diff --git a/mygpt.py b/mygpt.py index 57cbbc6..7ff1035 100755 --- a/mygpt.py +++ b/mygpt.py @@ -14,7 +14,7 @@ from torch.nn import functional as F ############################## -class Residual(nn.Module): +class WithResidual(nn.Module): def __init__(self, *f): super().__init__() self.f = f[0] if len(f) == 1 else nn.Sequential(*f) @@ -24,51 +24,57 @@ class Residual(nn.Module): ############################## -class PositionalEncoding(nn.Module): +class AddPositionalEncoding(nn.Module): def __init__(self, len_max): super().__init__() self.len_max = len_max - # From Vaswani et al 2018 - # PE_{t,2i} = sin(t/(L^{2i/D})) - # PE_{t,2i+1} = cos(t/(L^{2i/D})) + # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D})) def forward(self, x): t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None] j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :] k = j%2 - return x + torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)[None, :, :] + pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k) + return x + pe ############################## class QKVAttention(nn.Module): - def __init__(self, dim_in, dim_qk, dim_v, + def __init__(self, + dim_in, dim_qk, dim_v, nb_heads = 1, causal = False, attention_dropout = 0.0): super().__init__() def randw(*d): - return nn.Parameter(torch.empty(*d).normal_(0, 1 / math.sqrt(d[-1]))) + return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) + + self.causal = causal + self.attention_dropout = attention_dropout self.w_q = randw(nb_heads, dim_qk, dim_in) self.w_k = randw(nb_heads, dim_qk, dim_in) self.w_v = randw(nb_heads, dim_v, dim_in) - self.w_o = randw(dim_in, dim_v * nb_heads) - self.causal = causal - self.attention_dropout = attention_dropout + self.w_o = randw(dim_v * nb_heads, dim_in) def forward(self, x_q, x_kv = None): if x_kv is None: x_kv = x_q + q = torch.einsum('ntc,hdc->nhtd', x_q, self.w_q) k = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_k) v = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_v) + a = torch.einsum('nhtd,nhsd->nhts', q, k) / math.sqrt(q.size(3)) + if self.causal: mask = torch.arange(a.size(2), device = q.device)[None, None, :, None] \ < torch.arange(a.size(3), device = q.device)[None, None, None, :] a = a.masked_fill(mask, float('-inf')) + a = a.softmax(dim = 3) a = F.dropout(a, self.attention_dropout, self.training) - y = torch.einsum('nhts,nhsd->nthd', a, v) - y = y.flatten(2) @ self.w_o + y = torch.einsum('nhts,nhsd->nthd', a, v).flatten(2) + + y = y @ self.w_o return y @@ -78,7 +84,8 @@ class MyGPT(nn.Module): def __init__(self, vocabulary_size, dim_model, dim_keys, dim_hidden, - nb_heads, nb_blocks, dropout = 0.): + nb_heads, nb_blocks, + dropout = 0.0, len_max = 1e5): super().__init__() @@ -87,25 +94,25 @@ class MyGPT(nn.Module): self.embedding = nn.Sequential( nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout), - PositionalEncoding(len_max = 1e5), + AddPositionalEncoding(len_max), ) trunk_blocks = [ ] for _ in range(nb_blocks): trunk_blocks += [ - Residual( - nn.LayerNorm(dim_model), + WithResidual( + nn.LayerNorm((dim_model,)), QKVAttention( dim_in = dim_model, - dim_qk = dim_keys, dim_v = dim_model // nb_heads, + dim_qk = dim_keys, + dim_v = dim_model // nb_heads, nb_heads = nb_heads, causal = True, attention_dropout = dropout ), - nn.Linear(in_features = dim_model, out_features = dim_model), ), - Residual( - nn.LayerNorm(dim_model), + WithResidual( + nn.LayerNorm((dim_model,)), nn.Linear(in_features = dim_model, out_features = dim_hidden), nn.ReLU(), nn.Linear(in_features = dim_hidden, out_features = dim_model), @@ -118,6 +125,7 @@ class MyGPT(nn.Module): self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size) def forward(self, x): + x = F.pad(x, (1, -1)) x = self.embedding(x) x = self.trunk(x) x = self.readout(x) @@ -133,7 +141,7 @@ if __name__ == '__main__': model = MyGPT( vocabulary_size = vocabulary_size, - dim_model = 16, dim_keys = 50, dim_hidden = 100, + dim_model = 18, dim_keys = 50, dim_hidden = 100, nb_heads = 2, nb_blocks = 3, dropout = 0.1 )