OCDC
[mygpt.git] / mygpt.py
index 4951460..212e1a5 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -36,37 +36,49 @@ class PositionalEncoding(nn.Module):
         t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None]
         j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :]
         k = j%2
-        return x + torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)[None, :, :]
+        pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)
+        return x + pe
 
 ##############################
 
 class QKVAttention(nn.Module):
-    def __init__(self, dim_in, dim_qk, dim_v,
+    def __init__(self,
+                 dim_in, dim_qk, dim_v,
                  nb_heads = 1, causal = False, attention_dropout = 0.0):
         super().__init__()
 
         def randw(*d):
-            return nn.Parameter(torch.empty(*d).normal_(0, 1 / math.sqrt(d[-1])))
+            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+
+        self.causal = causal
+        self.attention_dropout = attention_dropout
 
         self.w_q = randw(nb_heads, dim_qk, dim_in)
         self.w_k = randw(nb_heads, dim_qk, dim_in)
         self.w_v = randw(nb_heads, dim_v, dim_in)
-        self.causal = causal
-        self.attention_dropout = attention_dropout
+        self.w_o = randw(dim_v * nb_heads, dim_in)
 
     def forward(self, x_q, x_kv = None):
         if x_kv is None: x_kv = x_q
+
         q = torch.einsum('ntc,hdc->nhtd', x_q, self.w_q)
         k = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_k)
         v = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_v)
+
         a = torch.einsum('nhtd,nhsd->nhts', q, k) / math.sqrt(q.size(3))
+
         if self.causal:
-            mask = torch.tril(q.new_ones(a.size(2), a.size(3)))[None, None, :, :] == 0
+            mask = torch.arange(a.size(2), device = q.device)[None, None, :, None] \
+                   < torch.arange(a.size(3), device = q.device)[None, None, None, :]
             a = a.masked_fill(mask, float('-inf'))
+
         a = a.softmax(dim = 3)
         a = F.dropout(a, self.attention_dropout, self.training)
-        y = torch.einsum('nhts,nhsd->nhtd', a, v)
-        return y.permute(0, 2, 1, 3).flatten(2) # nhtd -> nt(hd)
+        y = torch.einsum('nhts,nhsd->nthd', a, v).flatten(2)
+
+        y = y @ self.w_o
+
+        return y
 
 ##############################
 
@@ -74,7 +86,8 @@ class MyGPT(nn.Module):
     def __init__(self,
                  vocabulary_size,
                  dim_model, dim_keys, dim_hidden,
-                 nb_heads, nb_blocks, dropout = 0.):
+                 nb_heads, nb_blocks,
+                 dropout = 0.0, len_max = 1e5):
 
         super().__init__()
 
@@ -83,7 +96,7 @@ class MyGPT(nn.Module):
         self.embedding = nn.Sequential(
             nn.Embedding(vocabulary_size, dim_model),
             nn.Dropout(dropout),
-            PositionalEncoding(len_max = 1e5),
+            PositionalEncoding(len_max),
         )
 
         trunk_blocks = [ ]
@@ -94,11 +107,11 @@ class MyGPT(nn.Module):
                     nn.LayerNorm(dim_model),
                     QKVAttention(
                         dim_in = dim_model,
-                        dim_qk = dim_keys, dim_v = dim_model // nb_heads,
+                        dim_qk = dim_keys,
+                        dim_v = dim_model // nb_heads,
                         nb_heads = nb_heads,
                         causal = True, attention_dropout = dropout
                     ),
-                    nn.Linear(in_features = dim_model, out_features = dim_model),
                 ),
                 Residual(
                     nn.LayerNorm(dim_model),
@@ -114,20 +127,23 @@ class MyGPT(nn.Module):
         self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size)
 
     def forward(self, x):
+        x = F.pad(x, (1, 0))
         x = self.embedding(x)
         x = self.trunk(x)
         x = self.readout(x)
-        return x
+        return x[:, :-1]
 
 ######################################################################
 
 if __name__ == '__main__':
+    print('Basic check.')
+
     vocabulary_size = 10
     x = torch.randint(vocabulary_size, (25, 100))
 
     model = MyGPT(
         vocabulary_size = vocabulary_size,
-        dim_model = 16, dim_keys = 50, dim_hidden = 100,
+        dim_model = 18, dim_keys = 50, dim_hidden = 100,
         nb_heads = 2, nb_blocks = 3,
         dropout = 0.1
     )