OCDC
author Francois Fleuret <francois@fleuret.org>
Thu, 28 Jul 2022 19:53:21 +0000 (21:53 +0200)
committer Francois Fleuret <francois@fleuret.org>
Thu, 28 Jul 2022 19:53:21 +0000 (21:53 +0200)
mygpt.py

index 7c4e06d..212e1a5 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -37,16 +37,14 @@ class PositionalEncoding(nn.Module):
         j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :]
         k = j%2
         pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)
-        return x + pe # Let broadcasting to its job
+        return x + pe
 
 ##############################
 
 class QKVAttention(nn.Module):
-    def __init__(
-            self,
-            dim_in, dim_qk, dim_v,
-            nb_heads = 1, causal = False, attention_dropout = 0.0
-    ):
+    def __init__(self,
+                 dim_in, dim_qk, dim_v,
+                 nb_heads = 1, causal = False, attention_dropout = 0.0):
         super().__init__()
 
         def randw(*d):
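
For reference, a minimal standalone sketch of the sinusoidal encoding computed in the forward pass above. The (N, T, D) input shape and the concrete sizes below are illustrative assumptions, not part of the commit:

import math
import torch

# For position t and channel j, with k = j % 2:
#   pe[t, j] = sin(t / len_max^((j - k) / D) + (pi/2) * k)
# i.e. a sine on even channels and, via the pi/2 phase shift,
# a cosine on odd channels.
len_max = 1e5                          # default used by MyGPT after this commit
x = torch.randn(2, 8, 16)              # (N, T, D) activations, values assumed
t = torch.arange(x.size(1), dtype=x.dtype)[:, None]  # positions, shape (T, 1)
j = torch.arange(x.size(2), dtype=x.dtype)[None, :]  # channels, shape (1, D)
k = j % 2
pe = torch.sin(t / (len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k)
y = x + pe                             # pe is (T, D); broadcasts over the batch dim
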
@@ -88,7 +86,8 @@ class MyGPT(nn.Module):
     def __init__(self,
                  vocabulary_size,
                  dim_model, dim_keys, dim_hidden,
-                 nb_heads, nb_blocks, dropout = 0.):
+                 nb_heads, nb_blocks,
+                 dropout = 0.0, len_max = 1e5):
 
         super().__init__()
 
@@ -97,7 +96,7 @@ class MyGPT(nn.Module):
         self.embedding = nn.Sequential(
             nn.Embedding(vocabulary_size, dim_model),
             nn.Dropout(dropout),
-            PositionalEncoding(len_max = 1e5),
+            PositionalEncoding(len_max),
         )
 
         trunk_blocks = [ ]
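
With this change, callers can set the positional-encoding scale instead of relying on the previously hard-coded 1e5, which remains the default. A hypothetical instantiation; all hyperparameter values here are illustrative, not taken from the repository:

from mygpt import MyGPT

# Illustrative hyperparameters (assumed, not from the repository).
model = MyGPT(
    vocabulary_size = 256,
    dim_model = 64, dim_keys = 16, dim_hidden = 128,
    nb_heads = 4, nb_blocks = 2,
    dropout = 0.1,
    len_max = 1e4,  # now configurable; defaults to 1e5 as before
)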