Update.
[mygptrnn.git] / mygpt.py
index a27b99e..492a9bb 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -597,36 +597,14 @@ class Caterpillar(nn.Module):
             torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
         ).sigmoid()
 
-        # warnings.warn("softmax gating", RuntimeWarning)
+        # Clip the gating to avoid values greater than 1 when several
+        # heads hit the same row
 
-        # G = (
-        # torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
-        # ).softmax(dim=2)
+        G = G / G.sum(1, keepdim=True).clamp(min=1)
 
         ######################################################################
-        # The "flashbacks"
-
-        if self.training and self.proba_gate_dropout > 0.0:
-            # This is a better implementation of "flashbacks".
-
-            # G is NxHxExT where e is the caterpillar's row.
-
-            warnings.warn("gate dropout", RuntimeWarning)
-
-            kill = (
-                torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout
-            ).float()
-
-            alpha = G / (1 - self.proba_gate_dropout)
-
-            G = alpha * (1 - kill)
 
         def recurrence(G, V, K):
-            # Clip the gating to avoid values greater than 1 when several
-            # heads hit the same row
-
-            G = G / G.sum(1, keepdim=True).clamp(min=1)
-
             # We prepare the arguments for the parallel scan
 
             A = 1 - G.sum(1)
@@ -663,6 +641,26 @@ class Caterpillar(nn.Module):
 
         next_V, next_K = recurrence(G, V, K)
 
+        if self.training and self.proba_gate_dropout > 0.0:
+            # G is NxHxRxT where r is the caterpillar's row.
+
+            warnings.warn("gate dropout", RuntimeWarning)
+
+            kill = (
+                torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout
+            ).float()
+
+            mask = 1 - kill
+
+            masked_next_V, masked_next_K = recurrence(G * mask, V, K)
+
+            next_V = next_V.detach() + (masked_next_V - masked_next_V.detach()) / (
+                1 - self.proba_gate_dropout
+            )
+            next_K = next_K.detach() + (masked_next_K - masked_next_K.detach()) / (
+                1 - self.proba_gate_dropout
+            )
+
         self.rec_V[:, :, t0:t1] = next_V
         self.rec_K[:, :, t0:t1] = next_K