Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 10 Jan 2024 07:11:40 +0000 (08:11 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 10 Jan 2024 07:11:40 +0000 (08:11 +0100)
fridge
mygpt.py

diff --git a/fridge b/fridge
index d28cc89..bb6f46e 100644 (file)
--- a/fridge
+++ b/fridge
@@ -117,3 +117,11 @@ def insert_flash_back(rec_V, V, rec_K, K, t0, t1, CL, proba):
                 # mask * K[n, src_head, src_time, dk]
                 # + (1 - mask) * self.rec_K[:, :, t0:t1]
             # )
+
+######################################################################
+
+2024 Jan 10 08:10:39 (from mygpt.py)
+
+        # That was a bad idea
+        # G = F.dropout(G, self.attention_dropout, self.training)
+
index 95e5527..17f2f6d 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -545,42 +545,45 @@ class Caterpillar(nn.Module):
 
         # This is the Gating sequence that modulates the storing of
         # the new key and value in the CH pairs of the current
-        # stack. The CH gating values are independent, which means
-        # that the current K/V could be stored in multiple pairs of the
+        # stack. There are CH independent gating values, which means
+        # that the current K/V may be stored in multiple pairs of the
         # recurrent state, or not at all.
 
         G = (
             torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
         ).sigmoid()
 
+        # Clip the gating to avoid values greater than 1 when several
+        # heads hit the same row
+
+        G = G / G.sum(1, keepdim=True).clamp(min=1)
+
         if self.training and self.proba_gate_dropout > 0.0:
-            warnings.warn("gate droupout", RuntimeWarning)
+            warnings.warn("gate dropout", RuntimeWarning)
             epsilon = 0.5
 
-        # That was a bad idea
-        # G = F.dropout(G, self.attention_dropout, self.training)
-
         V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
         K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
 
         # We prepare the arguments for the parallel scan
 
-        # Clip the gating
-        warnings.warn("gating clipping", RuntimeWarning)
-        G = G / G.sum(1, keepdim=True).clamp(min=1)
-
         A = 1 - G.sum(1)
         gated_V = torch.einsum("nhet,nhtd->netd", G, V)
         gated_K = torch.einsum("nhet,nhtd->netd", G, K)
 
+        # Initial recurrent state
+
         init_rec_V = self.rec_V[:, :, t0 - CL : t0]
         init_rec_K = self.rec_K[:, :, t0 - CL : t0]
 
-        # Here there is a trick: Since the stack at time t is computed
-        # by updating that at time t-L, the parallel scan operates
-        # with a period of L. To do so we split the time indexing in
-        # two axes, the second of size CL, and run the parallel scan
-        # using the other as the sequence index.
+        #################################################################
+        # Associative scan
+
+        # Here there is a trick: Since the stack at position t is
+        # computed by updating that at position t-CL, the parallel
+        # scan operates with a period of CL. To do so we split the
+        # sequence indexing in two axes, the second of size CL, and
+        # run the parallel scan using the first as the sequence index.
 
         A = A.unflatten(2, (-1, CL))
         gated_V = gated_V.unflatten(2, (-1, CL))
@@ -589,11 +592,11 @@ class Caterpillar(nn.Module):
         next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
         next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)
 
-        # Put back the sequence index
-
         self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
         self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)
 
+        #################################################################
+
         if self.training and self.proba_flashback > 0.0:
             warnings.warn("flash back", RuntimeWarning)
             # This piece of code makes the assumption that there is