Added the small weight embedding + id layer norm inits.
[mygpt.git] / mygpt.py
index 121ad80..ebc9a83 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -14,7 +14,7 @@ from torch.nn import functional as F
 
 ##############################
 
-class Residual(nn.Module):
+class WithResidual(nn.Module):
     def __init__(self, *f):
         super().__init__()
         self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
@@ -24,28 +24,25 @@ class Residual(nn.Module):
 
 ##############################
 
-class PositionalEncoding(nn.Module):
+class AddPositionalEncoding(nn.Module):
     def __init__(self, len_max):
         super().__init__()
         self.len_max = len_max
 
-    # From Vaswani et al 2018
-    # PE_{t,2i}   = sin(t/(L^{2i/D}))
-    # PE_{t,2i+1} = cos(t/(L^{2i/D}))
+    # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
     def forward(self, x):
         t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None]
         j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :]
         k = j%2
-        return x + torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)[None, :, :]
+        pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)
+        return x + pe
 
 ##############################
 
 class QKVAttention(nn.Module):
-    def __init__(
-            self,
-            dim_in, dim_qk, dim_v,
-            nb_heads = 1, causal = False, attention_dropout = 0.0
-    ):
+    def __init__(self,
+                 dim_in, dim_qk, dim_v,
+                 nb_heads = 1, causal = False, attention_dropout = 0.0):
         super().__init__()
 
         def randw(*d):
@@ -57,7 +54,7 @@ class QKVAttention(nn.Module):
         self.w_q = randw(nb_heads, dim_qk, dim_in)
         self.w_k = randw(nb_heads, dim_qk, dim_in)
         self.w_v = randw(nb_heads, dim_v, dim_in)
-        self.w_o = randw(dim_in, dim_v * nb_heads)
+        self.w_o = randw(dim_v * nb_heads, dim_in)
 
     def forward(self, x_q, x_kv = None):
         if x_kv is None: x_kv = x_q
@@ -87,7 +84,8 @@ class MyGPT(nn.Module):
     def __init__(self,
                  vocabulary_size,
                  dim_model, dim_keys, dim_hidden,
-                 nb_heads, nb_blocks, dropout = 0.):
+                 nb_heads, nb_blocks,
+                 dropout = 0.0, len_max = 1e5):
 
         super().__init__()
 
@@ -96,24 +94,25 @@ class MyGPT(nn.Module):
         self.embedding = nn.Sequential(
             nn.Embedding(vocabulary_size, dim_model),
             nn.Dropout(dropout),
-            PositionalEncoding(len_max = 1e5),
+            AddPositionalEncoding(len_max),
         )
 
         trunk_blocks = [ ]
 
         for _ in range(nb_blocks):
             trunk_blocks += [
-                Residual(
-                    nn.LayerNorm(dim_model),
+                WithResidual(
+                    nn.LayerNorm((dim_model,)),
                     QKVAttention(
                         dim_in = dim_model,
-                        dim_qk = dim_keys, dim_v = dim_model // nb_heads,
+                        dim_qk = dim_keys,
+                        dim_v = dim_model // nb_heads,
                         nb_heads = nb_heads,
                         causal = True, attention_dropout = dropout
                     ),
                 ),
-                Residual(
-                    nn.LayerNorm(dim_model),
+                WithResidual(
+                    nn.LayerNorm((dim_model,)),
                     nn.Linear(in_features = dim_model, out_features = dim_hidden),
                     nn.ReLU(),
                     nn.Linear(in_features = dim_hidden, out_features = dim_model),
@@ -125,12 +124,20 @@ class MyGPT(nn.Module):
 
         self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size)
 
+        with torch.no_grad():
+            for m in self.modules():
+                if isinstance(m, nn.Embedding):
+                    m.weight.normal_(mean = 0, std = 2e-2)
+                elif isinstance(m, nn.LayerNorm):
+                    m.bias.zero_()
+                    m.weight.fill_(1.0)
+
     def forward(self, x):
-        x = F.pad(x, (1, 0))
+        x = F.pad(x, (1, -1))
         x = self.embedding(x)
         x = self.trunk(x)
         x = self.readout(x)
-        return x[:, :-1]
+        return x
 
 ######################################################################
 
@@ -142,7 +149,7 @@ if __name__ == '__main__':
 
     model = MyGPT(
         vocabulary_size = vocabulary_size,
-        dim_model = 16, dim_keys = 50, dim_hidden = 100,
+        dim_model = 18, dim_keys = 50, dim_hidden = 100,
         nb_heads = 2, nb_blocks = 3,
         dropout = 0.1
     )