Update.
[mygptrnn.git] / mygpt.py
index d1acf22..c061eb4 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -441,6 +441,11 @@ class KVRec(nn.Module):
 ##############################
 
 
+# Returns a tensor with an additional index at rank win_dim, that move
+# along the same dimension as dim, on a domain {0...win_size-1}, and
+# dim is restricted on a domain reduced by win_size-1 values.
+
+
 def moving_window(x, dim, win_dim, win_size):
     size, stride = x.size(), x.stride()
     size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
@@ -530,16 +535,18 @@ class Caterpillar(nn.Module):
         ######################################################################
         # Compute the recurrent state
 
-        # This is the Gating sequence that modulates if they key and
-        # values should be stored in one of the CH pairs of the
-        # current stack. The CH gating values are independent, which
-        # means that the same thing could be stored up to CH times or
-        # not at all
+        # This is the Gating sequence that modulates the storing of
+        # the new key and value in the CH pairs of the current
+        # stack. The CH gating values are independent, which means
+        # that the current K/V could be stored in all the pairs of the
+        # recurrent state, or not at all.
 
         G = (
             torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
         ).sigmoid()
 
+        G = F.dropout(G, self.attention_dropout, self.training)
+
         V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
         K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
 
@@ -552,10 +559,11 @@ class Caterpillar(nn.Module):
         init_rec_V = self.rec_V[:, :, t0 - CL : t0]
         init_rec_K = self.rec_K[:, :, t0 - CL : t0]
 
-        # Here there is a trick: The parallel scan operates with a
-        # period of L, so we split the sequence indexing in two axes,
-        # the second of size CL, and run the parallel scan using the
-        # other alone as the sequence index.
+        # Here there is a trick: Since the stack at time t is computed
+        # by updating that at time t-L, the parallel scan operates
+        # with a period of L. To do so we split the time indexing in
+        # two axes, the second of size CL, and run the parallel scan
+        # using the other alone as the sequence index.
 
         A = A.unflatten(2, (-1, CL))
         gated_V = gated_V.unflatten(2, (-1, CL))