X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=3bce361424ec4a8b37150c300dc0df05b01de7ef;hb=c3621f9a75cd4d79410d90a29dc9fdec401eaa2d;hp=d6879dc08a29f05cac1998bc1ab16e46db07821c;hpb=52c6bd98650c846459f10e8303dd2e6c7ba2a68f;p=mygpt.git diff --git a/mygpt.py b/mygpt.py index d6879dc..3bce361 100755 --- a/mygpt.py +++ b/mygpt.py @@ -14,7 +14,7 @@ from torch.nn import functional as F ############################## -class Residual(nn.Module): +class WithResidual(nn.Module): def __init__(self, *f): super().__init__() self.f = f[0] if len(f) == 1 else nn.Sequential(*f) @@ -24,14 +24,12 @@ class Residual(nn.Module): ############################## -class PositionalEncoding(nn.Module): +class AddPositionalEncoding(nn.Module): def __init__(self, len_max): super().__init__() self.len_max = len_max - # From Vaswani et al 2018 - # PE_{t,2i} = sin(t/(L^{2i/D})) - # PE_{t,2i+1} = cos(t/(L^{2i/D})) + # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D})) def forward(self, x): t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None] j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :] @@ -96,14 +94,18 @@ class MyGPT(nn.Module): self.embedding = nn.Sequential( nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout), - PositionalEncoding(len_max), + AddPositionalEncoding(len_max), ) + # Small embedding initialization + with torch.no_grad(): + self.embedding[0].weight.normal_(0, 2e-2) + trunk_blocks = [ ] for _ in range(nb_blocks): trunk_blocks += [ - Residual( + WithResidual( nn.LayerNorm((dim_model,)), QKVAttention( dim_in = dim_model, @@ -113,7 +115,7 @@ class MyGPT(nn.Module): causal = True, attention_dropout = dropout ), ), - Residual( + WithResidual( nn.LayerNorm((dim_model,)), nn.Linear(in_features = dim_model, out_features = dim_hidden), nn.ReLU(), @@ -127,11 +129,10 @@ class MyGPT(nn.Module): self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size) def forward(self, x): - x = F.pad(x, (1, 0)) + x = F.pad(x, (1, -1)) x = self.embedding(x) x = self.trunk(x) x = self.readout(x) - x = F.pad(x, (0, 0, 0, -1)) return x ######################################################################