Update.
[mygptrnn.git] / mygpt.py
index aded796..760a3c6 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -126,7 +126,6 @@ class AddPositionalEncoding(nn.Module):
 
 import pscan
 
-
 # X is /.../xTxD   A is /.../xT   Y_init is /.../xD
 
 
@@ -147,6 +146,18 @@ def pscan_dim(A, X, Y_init, dim=-2):
     return Y
 
 
+def pscan_rgrad(grad_Y, A, X, Y_init, dim=-2, eps=1e-2):
+    with torch.no_grad():
+        s_A, s_X = 0, 0
+        for t in range(X.size(dim) - 1, 0, -1):
+            delta = (grad_Y[t] - s_A) / A[t].grad
+            s_A += A[t].grad * delta
+            A[t].grad = delta
+            delta = (grad_Y[t] - s_X) / X[t].grad
+            s_X += X[t].grad * delta
+            X[t].grad = delta
+
+
 def pscan_shape(A, X, Y_init):
     s = X.size()
     A = A.reshape(-1, s[-2])
@@ -191,7 +202,7 @@ class DumbRec(nn.Module):
         attention_dropout=0.0,
         len_max=1e5,
         logger=print,
-        **kwargs,
+        args=None,
     ):
         super().__init__()
 
@@ -322,7 +333,7 @@ class KVRec(nn.Module):
         attention_dropout=0.0,
         len_max=1e5,
         logger=print,
-        **kwargs,
+        args=None,
     ):
         super().__init__()
 
@@ -464,36 +475,6 @@ def moving_window(x, dim, win_dim, win_size):
 ##############################
 
 
-class Calibrator:
-    def __init__(self, w=None, b=None):
-        self.w = w
-        self.b = b
-        self.s, self.s_sq, self.n = 0, 0, 0
-        self.mean, self.std = 0, 0
-
-    def update(self, X):
-        X = X.detach()
-        self.s += X.sum(dim=0)
-        self.s_sq += X.pow(2).sum(dim=0)
-        self.n += X.size(0)
-
-    def moments(self):
-        mean = self.s / self.n
-        std = (self.s_sq / self.n - mean * mean).sqrt()
-        return mean, std
-
-    def normalize(self):
-        mean, std = self.moments()
-        if self.b is not None:
-            self.b.sub_(mean)
-        if self.w is not None:
-            self.w.div_(std)
-        result = mean - self.mean, std - self.std
-        self.mean, self.std = mean, std
-        self.s, self.s_sq, self.n = 0, 0, 0
-        return result
-
-
 class Caterpillar(nn.Module):
     def __init__(
         self,
@@ -506,7 +487,7 @@ class Caterpillar(nn.Module):
         attention_dropout=0.0,
         len_max=1e5,
         logger=print,
-        **kwargs,
+        args=None,
     ):
         super().__init__()
 
@@ -521,27 +502,12 @@ class Caterpillar(nn.Module):
         self.caterpillar_height = caterpillar_height
         self.attention_dropout = attention_dropout
 
-        ######################################################################
-        # sup_args
-
-        x = kwargs.get("gate_dropout")
-        if x is None:
-            self.proba_gate_dropout = 0.0
-        else:
-            self.proba_gate_dropout = float(x)
-
-        logger(f"self.proba_gate_dropout {self.proba_gate_dropout}")
-
-        x = kwargs.get("default_bg")
-        if x is None:
-            default_bg = -math.log(caterpillar_height - 1)
-        else:
-            default_bg = float(x)
-
-        logger(f"default_bg {default_bg}")
+        self.gate_dropout_proba = args.gate_dropout_proba
+        self.gate_dropout_sync = args.gate_dropout_sync
 
         ######################################################################
 
+        default_bg = -math.log(caterpillar_height - 1)
         self.w_G = randw(nb_heads, caterpillar_height, dim_model)
         self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_bg))
 
@@ -561,10 +527,6 @@ class Caterpillar(nn.Module):
             dim_v,
         )
 
-        self.calibrator_G = Calibrator()
-        self.calibrator_rec_V = Calibrator()
-        self.calibrator_rec_K = Calibrator()
-
     def reset_inner_loss(self):
         self.acc_attention = 0
         self.acc_nb = 0
@@ -620,90 +582,81 @@ class Caterpillar(nn.Module):
             torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
         ).sigmoid()
 
-        self.calibrator_G.update(G.reshape(-1, G.size(-1)))
-
-        # warnings.warn("softmax gating", RuntimeWarning)
+        # Clip the gating to avoid values greater than 1 when several
+        # heads hit the same row
 
-        # G = (
-        # torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
-        # ).softmax(dim=2)
+        G = G / G.sum(1, keepdim=True).clamp(min=1)
 
         ######################################################################
-        # The "flashbacks"
 
-        if self.training and self.proba_gate_dropout > 0.0:
-            # This is a better implementation of "flashbacks".
+        def recurrence(G, V, K):
+            # We prepare the arguments for the parallel scan
 
-            # G is NxHxExT where e is the caterpillar's row.
+            A = 1 - G.sum(1)
 
-            warnings.warn("gate dropout", RuntimeWarning)
+            gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
+            gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)
 
-            kill = (
-                torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout
-            ).float()
+            # We start from cached values, which matters in inference
 
-            alpha = G / (1 - self.proba_gate_dropout)
+            init_rec_V = self.rec_V[:, :, t0 - L : t0]
+            init_rec_K = self.rec_K[:, :, t0 - L : t0]
 
-            G = alpha * (1 - kill)
+            # Here there is a trick: Since the stack at position t is
+            # computed by updating that at position t-L, the parallel
+            # scan operates with a period of L. To do so we split the
+            # sequence indexing in two axes, the second of size L, and
+            # run the parallel scan using the first as the sequence index.
 
-        ######################################################################
-        # Clip the gating to avoid values greater than 1 when several
-        # heads hit the same row
+            A = A.unflatten(2, (-1, L))
+            gated_V = gated_V.unflatten(2, (-1, L))
+            gated_K = gated_K.unflatten(2, (-1, L))
 
-        G = G / G.sum(1, keepdim=True).clamp(min=1)
+            next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
+            next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)
 
-        ######################################################################
-        # Roll the gating indexes
+            next_V = next_V.flatten(2, 3)
+            next_K = next_K.flatten(2, 3)
 
-        # warnings.warn("rotating barrel", RuntimeWarning)
+            return next_V, next_K
 
-        # r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
-        # t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
-        # r_barrel = (r_barrel + (t_barrel + t0) // L) % R
-        # G = G.gather(dim=2, index=r_barrel.expand_as(G))
-
-        # We prepare the arguments for the parallel scan
-
-        A = 1 - G.sum(1)
-
-        # warnings.warn("harmonic recurrence", RuntimeWarning)
-        # har = torch.arange(t0, t1, device = G.device).float() + 1
-        # A = har / (har + 1)
-        # G = G / har
+        #################################################################
 
-        gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
-        gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)
+        next_V, next_K = recurrence(G, V, K)
 
-        # We start from cached values, which matters in inference
+        if self.training and self.gate_dropout_proba > 0.0:
+            # G is NxHxRxT where r is the caterpillar's row.
 
-        init_rec_V = self.rec_V[:, :, t0 - L : t0]
-        init_rec_K = self.rec_K[:, :, t0 - L : t0]
+            warnings.warn("gate dropout", RuntimeWarning)
 
-        #################################################################
-        # Associative scan
+            if self.gate_dropout_sync:
+                shape_kill = (N, 1, 1)
+            else:
+                shape_kill = (N, H, R)
 
-        # Here there is a trick: Since the stack at position t is
-        # computed by updating that at position t-L, the parallel
-        # scan operates with a period of L. To do so we split the
-        # sequence indexing in two axes, the second of size L, and
-        # run the parallel scan using the first as the sequence index.
+            # Pick a point in each of the NxHxR timeline and set this
+            # entry and the following to 1
+            kill = (
+                torch.rand(*shape_kill, t1 - t0, device=G.device).sort(dim=3).indices
+                == 0
+            ).cumsum(dim=3)
 
-        A = A.unflatten(2, (-1, L))
-        gated_V = gated_V.unflatten(2, (-1, L))
-        gated_K = gated_K.unflatten(2, (-1, L))
+            # Keep these mask for only some of the NxHxR
+            kill = kill * (
+                torch.rand(*shape_kill, 1, device=G.device) <= self.gate_dropout_proba
+            )
 
-        next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
-        next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)
+            # The coefficient to keep are the complementary
+            mask = 1 - kill
 
-        next_V = next_V.flatten(2, 3)
-        next_K = next_K.flatten(2, 3)
+            masked_next_V, masked_next_K = recurrence(G * mask, V, K)
 
-        self.calibrator_rec_V.update(
-            next_V.permute(0, 1, 3, 2).reshape(-1, next_V.size(2))
-        )
-        self.calibrator_rec_K.update(
-            next_K.permute(0, 1, 3, 2).reshape(-1, next_K.size(2))
-        )
+            next_V = next_V.detach() + (masked_next_V - masked_next_V.detach()) / (
+                1 - self.gate_dropout_proba
+            )
+            next_K = next_K.detach() + (masked_next_K - masked_next_K.detach()) / (
+                1 - self.gate_dropout_proba
+            )
 
         self.rec_V[:, :, t0:t1] = next_V
         self.rec_K[:, :, t0:t1] = next_K
@@ -713,8 +666,8 @@ class Caterpillar(nn.Module):
 
         Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
 
-        # We build tensors NxHxTxFxL where N is the sample index, H
-        # the head, T the time, F the row in the caterpillar, and L
+        # We build tensors NxHxTxRxL where N is the sample index, H
+        # the head, T the time, R the row in the caterpillar, and L
         # the column in the caterpillar
 
         windowed_V = moving_window(
@@ -728,7 +681,7 @@ class Caterpillar(nn.Module):
         # We have an attention score for each of the RxL values
 
         ar = torch.einsum(
-            "nhtd,nftld->nhtfl",
+            "nhtd,nrtld->nhtrl",
             Q,
             windowed_K,
         ) / math.sqrt(DK)
@@ -768,7 +721,7 @@ class QKVAttention(nn.Module):
         causal=False,
         attention_dropout=0.0,
         logger=print,
-        **kwargs,
+        args=None,
     ):
         super().__init__()
 
@@ -861,7 +814,7 @@ class MyGPT(nn.Module):
         len_max=1e5,
         attention_layer="kvrec",
         logger=print,
-        **kwargs,
+        args=None,
     ):
         super().__init__()
 
@@ -899,7 +852,7 @@ class MyGPT(nn.Module):
                     causal=causal,
                     attention_dropout=dropout,
                     logger=logger,
-                    **kwargs,
+                    args=args,
                 )
             elif attention_layer == "dumbrec":
                 return DumbRec(
@@ -910,7 +863,7 @@ class MyGPT(nn.Module):
                     nb_lines=nb_lines,
                     attention_dropout=dropout,
                     logger=logger,
-                    **kwargs,
+                    args=args,
                 )
             elif attention_layer == "kvrec":
                 return KVRec(
@@ -921,7 +874,7 @@ class MyGPT(nn.Module):
                     nb_lines=nb_lines,
                     attention_dropout=dropout,
                     logger=logger,
-                    **kwargs,
+                    args=args,
                 )
             elif attention_layer == "caterpillar":
                 return Caterpillar(
@@ -933,7 +886,7 @@ class MyGPT(nn.Module):
                     caterpillar_height=self.caterpillar_height,
                     attention_dropout=dropout,
                     logger=logger,
-                    **kwargs,
+                    args=args,
                 )
             else:
                 raise ValueError(f"Unknown attention type {attention_layer}.")