Update.
author     François Fleuret <francois@fleuret.org>
           Thu, 18 Jan 2024 07:54:04 +0000 (08:54 +0100)
committer  François Fleuret <francois@fleuret.org>
           Thu, 18 Jan 2024 07:54:04 +0000 (08:54 +0100)
fridge
mygpt.py

diff --git a/fridge b/fridge
index d09e92d..2cc6d01 100644 (file)
--- a/fridge
+++ b/fridge
@@ -292,3 +292,13 @@ class Calibrator:
         # A = har / (har + 1)
         # G = G / har
 
+
+######################################################################
+
+2024 Jan 18 08:46:18 (from mygpt.py)
+
+        # warnings.warn("softmax gating", RuntimeWarning)
+
+        # G = (
+        # torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
+        # ).softmax(dim=2)
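
The hunk above retires the softmax-gating experiment to the fridge, while the sigmoid gating stays in mygpt.py. For context, a minimal self-contained sketch of the two variants, with toy sizes and assumed shapes rather than the repository's code: a sigmoid gates each head/row pair independently in (0, 1), whereas a softmax over dim=2 makes the R rows of a head compete for the same mass.

import torch

# Toy sizes, assumed for illustration: batch, time, channels, heads, rows.
N, T, C, H, R = 2, 5, 16, 4, 3

X = torch.randn(N, T, C)
w_G = torch.randn(H, R, C) / C**0.5
b_G = torch.randn(H, R)

logits = torch.einsum("ntc,hrc->nhrt", X, w_G) + b_G[None, :, :, None]

G_sigmoid = logits.sigmoid()       # each gate lives in (0, 1), independently
G_softmax = logits.softmax(dim=2)  # one head's gates sum to 1 across its R rows

assert G_sigmoid.shape == (N, H, R, T)
assert torch.allclose(G_softmax.sum(2), torch.ones(N, H, T))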
diff --git a/mygpt.py b/mygpt.py
index a27b99e..492a9bb 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -597,36 +597,14 @@ class Caterpillar(nn.Module):
             torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
         ).sigmoid()
 
-        # warnings.warn("softmax gating", RuntimeWarning)
+        # Clip the gating to avoid values greater than 1 when several
+        # heads hit the same row
 
-        # G = (
-        # torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
-        # ).softmax(dim=2)
+        G = G / G.sum(1, keepdim=True).clamp(min=1)
 
         ######################################################################
-        # The "flashbacks"
-
-        if self.training and self.proba_gate_dropout > 0.0:
-            # This is a better implementation of "flashbacks".
-
-            # G is NxHxExT where e is the caterpillar's row.
-
-            warnings.warn("gate dropout", RuntimeWarning)
-
-            kill = (
-                torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout
-            ).float()
-
-            alpha = G / (1 - self.proba_gate_dropout)
-
-            G = alpha * (1 - kill)
 
         def recurrence(G, V, K):
-            # Clip the gating to avoid values greater than 1 when several
-            # heads hit the same row
-
-            G = G / G.sum(1, keepdim=True).clamp(min=1)
-
             # We prepare the arguments for the parallel scan
 
             A = 1 - G.sum(1)
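
This hunk also hoists the gate clipping out of recurrence(), so G is normalized once before the scan arguments are built. Summed over the H heads, the gates writing into one caterpillar row must stay at or below 1, otherwise the forget factor A = 1 - G.sum(1) would go negative; clamp(min=1) rescales only the rows whose head total exceeds 1 and leaves the others untouched. A hedged sketch of the effect, with an assumed value path and a sequential loop standing in for the parallel scan (only the normalization and A appear in the diff):

import torch

N, H, R, T, D = 2, 4, 3, 6, 8  # toy sizes, assumed for illustration

G = torch.rand(N, H, R, T)     # raw sigmoid gates, head totals may exceed 1
V = torch.randn(N, H, T, D)    # assumed per-head values

# Rescale only where the per-row head total exceeds 1.
G = G / G.sum(1, keepdim=True).clamp(min=1)

A = 1 - G.sum(1)  # NxRxT retention factor, now guaranteed in [0, 1]
gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)

# Reference recurrence that the parallel scan computes:
#   rec[t] = A[t] * rec[t - 1] + gated_V[t]
rec = torch.zeros(N, R, D)
for t in range(T):
    rec = A[:, :, t, None] * rec + gated_V[:, :, t]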
@@ -663,6 +641,26 @@ class Caterpillar(nn.Module):
 
         next_V, next_K = recurrence(G, V, K)
 
+        if self.training and self.proba_gate_dropout > 0.0:
+            # G is NxHxRxT where r is the caterpillar's row.
+
+            warnings.warn("gate dropout", RuntimeWarning)
+
+            kill = (
+                torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout
+            ).float()
+
+            mask = 1 - kill
+
+            masked_next_V, masked_next_K = recurrence(G * mask, V, K)
+
+            next_V = next_V.detach() + (masked_next_V - masked_next_V.detach()) / (
+                1 - self.proba_gate_dropout
+            )
+            next_K = next_K.detach() + (masked_next_K - masked_next_K.detach()) / (
+                1 - self.proba_gate_dropout
+            )
+
         self.rec_V[:, :, t0:t1] = next_V
         self.rec_K[:, :, t0:t1] = next_K
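
The gate dropout, previously applied to G before the recurrence, is now expressed as a detach-based estimator after it: the recurrence is run a second time with randomly killed gates, and the outputs are combined so that the forward value is the clean next_V / next_K while the gradient flows only through the masked run, rescaled by 1 / (1 - proba_gate_dropout) as in inverted dropout. A minimal sketch of the trick in isolation, with a toy f standing in for recurrence() (assumed names, not the repository's code):

import torch

p = 0.25
x = torch.randn(8, requires_grad=True)

def f(x, mask):
    # Stand-in for the recurrence applied to masked gates.
    return (x * mask).sum()

mask = (torch.rand(8) > p).float()

y = f(x, torch.ones(8))  # clean forward pass
y_masked = f(x, mask)    # forward pass with dropped gates

out = y.detach() + (y_masked - y_masked.detach()) / (1 - p)

assert torch.allclose(out, y)  # the forward value is the clean one
out.backward()
assert torch.allclose(x.grad, mask / (1 - p))  # gradient from the masked path only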