Turn on gate dropout (proba 0.25), rename dropout_start to dropout_head, and disable the head-boost term in Caterpillar.
author    François Fleuret <francois@fleuret.org>
          Wed, 10 Jan 2024 16:58:03 +0000 (17:58 +0100)
committer François Fleuret <francois@fleuret.org>
          Wed, 10 Jan 2024 16:58:03 +0000 (17:58 +0100)
mygpt.py

index 7c8e9f4..ba93851 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -485,7 +485,7 @@ class Caterpillar(nn.Module):
         self.caterpillar_height = caterpillar_height
         self.attention_dropout = attention_dropout
 
-        self.proba_gate_dropout = 0.0
+        self.proba_gate_dropout = 0.25
 
         self.w_G = randw(nb_heads, caterpillar_height, dim_model, amplitude=1e-5)
         self.b_G = nn.Parameter(
@@ -572,7 +572,7 @@ class Caterpillar(nn.Module):
             warnings.warn("gate dropout", RuntimeWarning)
             epsilon = 0.5
 
-            dropout_start = (
+            dropout_head = (
                 (
                     torch.rand(G.size(), device=G.device)
                     .flatten(2, 3)
@@ -584,18 +584,18 @@ class Caterpillar(nn.Module):
                 .float()
             )
 
-            dropout_tail = dropout_start.cumsum(dim=3) - dropout_start
+            dropout_tail = dropout_head.cumsum(dim=3) - dropout_head
 
             dropout_active = (
                 torch.rand(N, 1, 1, 1, device=G.device) < self.proba_gate_dropout
             ).long()
 
-            dropout_start *= dropout_active
+            dropout_head *= dropout_active
             dropout_tail *= dropout_active
 
             G = (
                 G
-                + dropout_start * (1 - epsilon - G.detach())
+                # + dropout_head * (1 - epsilon - G.detach())
                 - dropout_tail * G.detach()
             )
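
The hunks above only show part of the mask construction, so here is a minimal standalone sketch of the gate-dropout logic they touch, assuming G holds per-head gate values in [0, 1] with shape (N, nb_heads, caterpillar_height, T). The function name gate_dropout and the argmax-over-random-draws head selection (elided by the hunk) are assumptions for illustration, not the repository's code.

import torch

def gate_dropout(G, proba_gate_dropout=0.25, epsilon=0.5):
    # G: (N, H, R, T) = (batch, heads, caterpillar_height, time).
    N, H, R, T = G.size()

    # Pick one random (row, time) position per (sample, head) as the start
    # ("head") of a dropout run: the argmax of i.i.d. uniform draws over the
    # flattened (R, T) plane (ties have probability zero with float draws).
    u = torch.rand(G.size(), device=G.device).flatten(2, 3)
    dropout_head = (
        (u == u.max(dim=2, keepdim=True).values).view(G.size()).float()
    )

    # The "tail" is every time step strictly after the head, in the same row.
    dropout_tail = dropout_head.cumsum(dim=3) - dropout_head

    # Only a proba_gate_dropout fraction of the samples are affected.
    dropout_active = (
        torch.rand(N, 1, 1, 1, device=G.device) < proba_gate_dropout
    ).long()
    dropout_head = dropout_head * dropout_active
    dropout_tail = dropout_tail * dropout_active

    # In the forward pass the tail gates become exactly zero (G - G = 0),
    # while detach() lets the gradient of the original G pass through. The
    # head-boost term toward 1 - epsilon is the one this commit disables.
    return (
        G
        # + dropout_head * (1 - epsilon - G.detach())
        - dropout_tail * G.detach()
    )

G = torch.rand(2, 4, 3, 16)  # toy gates: batch 2, 4 heads, R=3, T=16
G = gate_dropout(G)          # ~a quarter of samples get a run of zeroed gates

The net effect after this commit: for an affected sample, one randomly chosen recurrence row has its gates forced to zero from a random time step onward, while the G.detach() term keeps gradients flowing to the original gate values.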