Update.
[mygptrnn.git] / mygpt.py
index 040845e..12b3631 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -21,6 +21,8 @@ from torch.nn import functional as F
 
 import ffutils
 
+# from blanket import blanket
+
 # import memload
 
 ######################################################################
@@ -84,6 +86,18 @@ class CacheWrapper(nn.Module):
 ##############################
 
 
+class NaNChecker(nn.Module):
+    def __init__(self, name):
+        super().__init__()
+        self.name = name
+
+    def forward(self, bs):
+        x = bs.x if type(bs) is BracketedSequence else bs
+        assert not x.isnan().any(), f"${self.name} detected NaN"
+        assert not x.isinf().any(), f"${self.name} detected Inf"
+        return bs
+
+
 class WithResidual(nn.Module):
     def __init__(self, *f):
         super().__init__()
@@ -216,19 +230,9 @@ class DumbRec(nn.Module):
 
         self.w_qw = randw(nb_heads, dim_qk, dim_model)
         self.w_qr = randw(nb_heads, dim_qk, dim_model)
-        # self.w_k = randw(nb_heads, dim_qk, dim_model)
         self.w_v = randw(nb_heads, dim_v, dim_model)
         self.w_o = randw(dim_v * nb_heads, dim_model)
 
-    def reset_inner_loss(self):
-        self.acc_attention = 0
-        self.acc_nb = 0
-
-    def get_inner_loss(self):
-        warnings.warn("l2 regularization", RuntimeWarning)
-        return (self.acc_attention / self.acc_nb).pow(2).sum()
-        # return torch.tensor([0], device=self.w_qw.device)
-
     def forward(self, bs):
         x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
 
@@ -236,61 +240,33 @@ class DumbRec(nn.Module):
             self.rec_v = x_q.new_zeros(
                 x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
             )
-            # self.rec_k = x_q.new_zeros(
-            # x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
-            # )
             self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
 
-        ######################################################################
-        # Prepare the keys
-
-        k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
-
-        warnings.warn("rotating key barrel", RuntimeWarning)
-        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
-        t_barrel = torch.arange(t0, t1, device=k_star.device)
-        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
-        l_barrel = (
-            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
-        ) % k_star.size(0)
-        k_star = k_star[l_barrel, t_barrel]
-
         ######################################################################
         # Compute the recurrent state
 
         qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
 
         v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
-        # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
 
-        aw = torch.einsum(
-            "nhtd,ltd->nhlt",
-            qw,
-            k_star,
-        ) / math.sqrt(self.w_qw.size(1))
+        aw = torch.einsum("nhtd,ld->nhlt", qw, self.k_star) / math.sqrt(
+            self.w_qw.size(1)
+        )
 
         aw = aw.softmax(dim=2)  # nhlt
 
-        if self.train:
-            self.acc_attention += aw.sum(dim=(0, 1, 3))
-            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
-
         aw = F.dropout(aw, self.attention_dropout, self.training)
 
         A = 1 - aw.sum(dim=1)  # nlt
 
         V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
-        # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
 
         if t0 == 0:
             V0 = None
-            # K0 = None
         else:
             V0 = self.rec_v[:, :, t0 - 1]
-            # K0 = self.rec_k[:, :, t0 - 1]
 
         self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
-        # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
 
         ######################################################################
         # compute the readout
@@ -300,7 +276,6 @@ class DumbRec(nn.Module):
         ar = torch.einsum(
             "nhtd,ld->nhlt",
             qr,
-            # self.rec_k[:, :, t0:t1],
             self.k_star,
         ) / math.sqrt(self.w_qr.size(1))
 
@@ -356,9 +331,9 @@ class KVRec(nn.Module):
         self.acc_nb = 0
 
     def get_inner_loss(self):
-        warnings.warn("l2 regularization", RuntimeWarning)
-        return (self.acc_attention / self.acc_nb).pow(2).sum()
-        return torch.tensor([0], device=self.w_qw.device)
+        warnings.warn("l2 regularization", RuntimeWarning)
+        return (self.acc_attention / self.acc_nb).pow(2).sum()
+        return torch.tensor([0], device=self.w_qw.device)
         # warnings.warn("side regularization", RuntimeWarning)
         # return (
         # (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
@@ -382,12 +357,12 @@ class KVRec(nn.Module):
 
         k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
 
-        warnings.warn("rotating key barrel", RuntimeWarning)
+        warnings.warn("rotating key barrel", RuntimeWarning)
         k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
         t_barrel = torch.arange(t0, t1, device=k_star.device)
         t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
         l_barrel = (
-            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
+            torch.arange(k_star.size(0), device=k_star.device)[:, None]  # + t_barrel
         ) % k_star.size(0)
         k_star = k_star[l_barrel, t_barrel]
 
@@ -500,17 +475,13 @@ class Caterpillar(nn.Module):
         self.caterpillar_height = caterpillar_height
         self.attention_dropout = attention_dropout
 
-        self.gate_dropout_proba = args.gate_dropout_proba
-        self.gate_dropout_sync = args.gate_dropout_sync
-        self.gate_dropout_replace = args.gate_dropout_replace
-
         ######################################################################
 
-        self.w_G = randw(nb_heads, caterpillar_height, dim_model, factor=1.0)
+        self.w_G = randw(nb_heads, caterpillar_height, dim_model)
         self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), 0.0))
 
         self.w_K = randw(nb_heads, dim_qk, dim_model)
-        self.w_V = randw(nb_heads, dim_v, dim_model, factor=1)
+        self.w_V = randw(nb_heads, dim_v, dim_model)
         self.w_Q = randw(nb_heads, dim_qk, dim_model)
         self.w_O = randw(dim_v * nb_heads, dim_model)
 
@@ -583,83 +554,32 @@ class Caterpillar(nn.Module):
         # Clip the gating to avoid values greater than 1 when several
         # heads hit the same row
 
-        # G = G / G.sum(1, keepdim=True).clamp(min=1)
-
-        H = (1 - G).log().sum(1, keepdim=True).exp()
+        G = G / G.sum(1, keepdim=True).clamp(min=1)
 
         ######################################################################
 
-        def recurrence(G, V, K):
-            # We prepare the arguments for the parallel scan
-
-            A = H
-
-            gated_V = torch.einsum("nhrt,nhtd->nrtd", H * G / (1 - G), V)
-            gated_K = torch.einsum("nhrt,nhtd->nrtd", H * G / (1 - G), K)
-
-            # We start from cached values, which matters in inference
-
-            init_rec_V = self.rec_V[:, :, t0 - L : t0]
-            init_rec_K = self.rec_K[:, :, t0 - L : t0]
-
-            # Here there is a trick: Since the stack at position t is
-            # computed by updating that at position t-L, the parallel
-            # scan operates with a period of L. To do so we split the
-            # sequence indexing in two axes, the second of size L, and
-            # run the parallel scan using the first as the sequence index.
-
-            A = A.unflatten(2, (-1, L))
-            gated_V = gated_V.unflatten(2, (-1, L))
-            gated_K = gated_K.unflatten(2, (-1, L))
-
-            next_V = pscan_dim(A, gated_V, init_rec_V, dim=2).flatten(2, 3)
-            next_K = pscan_dim(A, gated_K, init_rec_K, dim=2).flatten(2, 3)
-
-            return next_V, next_K
+        A = 1 - G.sum(dim=1)
 
-        #################################################################
+        gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
+        gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)
 
-        next_V, next_K = recurrence(G, V, K)
+        # We start from cached values, which matters in inference
 
-        if self.training and self.gate_dropout_proba > 0.0:
-            # G is NxHxRxT where r is the caterpillar's row.
+        init_rec_V = self.rec_V[:, :, t0 - L : t0]
+        init_rec_K = self.rec_K[:, :, t0 - L : t0]
 
-            warnings.warn("gate dropout", RuntimeWarning)
+        # Here there is a trick: Since the stack at position t is
+        # computed by updating that at position t-L, the parallel
+        # scan operates with a period of L. To do so we split the
+        # sequence indexing in two axes, the second of size L, and
+        # run the parallel scan using the first as the sequence index.
 
-            if self.gate_dropout_sync:
-                shape_kill = (N, 1, 1)
-            else:
-                shape_kill = (N, H, R)
-
-            # Pick a point in each of the NxHxR timeline and set this
-            # entry and the following to 1
-            kill = (
-                torch.rand(*shape_kill, t1 - t0, device=G.device).sort(dim=3).indices
-                == 0
-            ).cumsum(dim=3)
-
-            # Keep these mask for only some of the NxHxR
-            kill = kill * (
-                torch.rand(*shape_kill, 1, device=G.device) <= self.gate_dropout_proba
-            )
-
-            # The coefficient to keep are the complementary
-            mask = 1 - kill
-
-            masked_next_V, masked_next_K = recurrence(G * mask, V, K)
-
-            if self.gate_dropout_replace:
-                next_V = next_V.detach()
-                next_K = next_K.detach()
+        A = A.unflatten(2, (-1, L))
+        gated_V = gated_V.unflatten(2, (-1, L))
+        gated_K = gated_K.unflatten(2, (-1, L))
 
-            warnings.warn("the rescaling is probably a bad idea", RuntimeWarning)
-
-            next_V = next_V + (masked_next_V - masked_next_V.detach()) / (
-                1 - self.gate_dropout_proba
-            )
-            next_K = next_K + (masked_next_K - masked_next_K.detach()) / (
-                1 - self.gate_dropout_proba
-            )
+        next_V = pscan_dim(A, gated_V, init_rec_V, dim=2).flatten(2, 3)
+        next_K = pscan_dim(A, gated_K, init_rec_K, dim=2).flatten(2, 3)
 
         self.rec_V[:, :, t0:t1] = next_V
         self.rec_K[:, :, t0:t1] = next_K
@@ -669,6 +589,8 @@ class Caterpillar(nn.Module):
 
         Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
 
+        # Q = blanket(Q)
+
         # We build tensors NxHxTxRxL where N is the sample index, H
         # the head, T the time, R the row in the caterpillar, and L
         # the column in the caterpillar
@@ -704,8 +626,6 @@ class Caterpillar(nn.Module):
             windowed_V,
         ).flatten(2)
 
-        # Compute the final output
-
         self.cache_Y[:, t0:t1] = Y @ self.w_O
 
         return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
@@ -722,6 +642,7 @@ class QKVAttention(nn.Module):
         dim_v,
         nb_heads=1,
         causal=False,
+        horizon=None,
         attention_dropout=0.0,
         logger=print,
         args=None,
@@ -732,6 +653,7 @@ class QKVAttention(nn.Module):
             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
 
         self.causal = causal
+        self.horizon = horizon
         self.attention_dropout = attention_dropout
         self.record_attention = False
 
@@ -775,6 +697,17 @@ class QKVAttention(nn.Module):
                     torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
                     < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
                 )
+
+                if self.horizon is not None:
+                    self.cache_attzero = torch.logical_or(
+                        self.cache_attzero,
+                        torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
+                        >= torch.arange(x_q.size(1), device=q.device)[
+                            None, None, None, :
+                        ]
+                        + self.horizon,
+                    )
+
             a = a.masked_fill(
                 self.cache_attzero[
                     :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb
@@ -821,14 +754,17 @@ class MyGPT(nn.Module):
     ):
         super().__init__()
 
+        self.vocabulary_size = vocabulary_size
+
         assert attention_layer in {
             "mha",
             "dumbrec",
             "kvrec",
             "caterpillar",
+            "attcat",
         }, f"Unknown attention operator {attention_layer}."
 
-        if attention_layer == "caterpillar":
+        if attention_layer == "caterpillar" or attention_layer == "attcat":
             assert nb_lines % caterpillar_height == 0
             self.caterpillar_length = nb_lines // caterpillar_height
             self.caterpillar_height = caterpillar_height
@@ -847,59 +783,99 @@ class MyGPT(nn.Module):
 
         def attlayer():
             if attention_layer == "mha":
-                return QKVAttention(
-                    dim_model=dim_model,
-                    dim_qk=dim_keys,
-                    dim_v=dim_model // nb_heads,
-                    nb_heads=nb_heads,
-                    causal=causal,
-                    attention_dropout=dropout,
-                    logger=logger,
-                    args=args,
+                return WithResidual(
+                    CacheWrapper(nn.LayerNorm((dim_model,))),
+                    QKVAttention(
+                        dim_model=dim_model,
+                        dim_qk=dim_keys,
+                        dim_v=dim_model // nb_heads,
+                        nb_heads=nb_heads,
+                        causal=causal,
+                        attention_dropout=dropout,
+                        logger=logger,
+                        args=args,
+                    ),
                 )
             elif attention_layer == "dumbrec":
-                return DumbRec(
-                    dim_model=dim_model,
-                    dim_qk=dim_keys,
-                    dim_v=dim_model // nb_heads,
-                    nb_heads=nb_heads,
-                    nb_lines=nb_lines,
-                    attention_dropout=dropout,
-                    logger=logger,
-                    args=args,
+                return WithResidual(
+                    CacheWrapper(nn.LayerNorm((dim_model,))),
+                    DumbRec(
+                        dim_model=dim_model,
+                        dim_qk=dim_keys,
+                        dim_v=dim_model // nb_heads,
+                        nb_heads=nb_heads,
+                        nb_lines=nb_lines,
+                        attention_dropout=dropout,
+                        logger=logger,
+                        args=args,
+                    ),
                 )
             elif attention_layer == "kvrec":
-                return KVRec(
-                    dim_model=dim_model,
-                    dim_qk=dim_keys,
-                    dim_v=dim_model // nb_heads,
-                    nb_heads=nb_heads,
-                    nb_lines=nb_lines,
-                    attention_dropout=dropout,
-                    logger=logger,
-                    args=args,
+                return WithResidual(
+                    CacheWrapper(nn.LayerNorm((dim_model,))),
+                    KVRec(
+                        dim_model=dim_model,
+                        dim_qk=dim_keys,
+                        dim_v=dim_model // nb_heads,
+                        nb_heads=nb_heads,
+                        nb_lines=nb_lines,
+                        attention_dropout=dropout,
+                        logger=logger,
+                        args=args,
+                    ),
                 )
             elif attention_layer == "caterpillar":
-                return Caterpillar(
-                    dim_model=dim_model,
-                    dim_qk=dim_keys,
-                    dim_v=dim_model // nb_heads,
-                    nb_heads=nb_heads,
-                    caterpillar_length=self.caterpillar_length,
-                    caterpillar_height=self.caterpillar_height,
-                    attention_dropout=dropout,
-                    logger=logger,
-                    args=args,
+                return WithResidual(
+                    CacheWrapper(nn.LayerNorm((dim_model,))),
+                    Caterpillar(
+                        dim_model=dim_model,
+                        dim_qk=dim_keys,
+                        dim_v=dim_model // nb_heads,
+                        nb_heads=nb_heads,
+                        caterpillar_length=self.caterpillar_length,
+                        caterpillar_height=self.caterpillar_height,
+                        attention_dropout=dropout,
+                        logger=logger,
+                        args=args,
+                    ),
+                )
+            elif attention_layer == "attcat":
+                return nn.Sequential(
+                    WithResidual(
+                        CacheWrapper(nn.LayerNorm((dim_model,))),
+                        QKVAttention(
+                            dim_model=dim_model,
+                            dim_qk=dim_keys,
+                            dim_v=dim_model // nb_heads,
+                            nb_heads=nb_heads,
+                            causal=causal,
+                            horizon=self.caterpillar_length,
+                            attention_dropout=dropout,
+                            logger=logger,
+                            args=args,
+                        ),
+                    ),
+                    WithResidual(
+                        CacheWrapper(nn.LayerNorm((dim_model,))),
+                        Caterpillar(
+                            dim_model=dim_model,
+                            dim_qk=dim_keys,
+                            dim_v=dim_model // nb_heads,
+                            nb_heads=nb_heads,
+                            caterpillar_length=self.caterpillar_length,
+                            caterpillar_height=self.caterpillar_height,
+                            attention_dropout=dropout,
+                            logger=logger,
+                            args=args,
+                        ),
+                    ),
                 )
             else:
                 raise ValueError(f"Unknown attention type {attention_layer}.")
 
         for b in range(nb_blocks):
             trunk_blocks += [
-                WithResidual(
-                    CacheWrapper(nn.LayerNorm((dim_model,))),
-                    attlayer(),
-                ),
+                attlayer(),
                 WithResidual(
                     CacheWrapper(
                         nn.LayerNorm((dim_model,)),
@@ -1081,31 +1057,35 @@ if __name__ == "__main__":
     # t = np.arange(dt, 20.0, dt)
     # ax.semilogx(t, np.exp(-t / 5.0))
     # ax.grid()
+    ax.set_yscale("log")
 
     ######################################################################
 
-    for label, model in [
-        # ("nn.Linear", linear),
-        ("mygpy.QKVAttention", qkv),
-        ("mygpt.Caterpillar", caterpillar),
+    for label, model, thickness in [
+        ("nn.Linear", linear, 0.2),
+        ("mygpy.QKVAttention", qkv, 1),
+        ("mygpt.Caterpillar", caterpillar, 2),
     ]:
         y = model(BracketedSequence(x, 32, x.size(1) - 32, init_cache=True)).x
 
-        data = []
-        for t in range(y.size(1)):
-            for d in torch.randperm(y.size(2))[:8]:
-                g = torch.autograd.grad(y[0, t, d], x, retain_graph=True)[0]
-                sg = g.pow(2).sum().item()
-                # sg = 0
-                # for p in model.parameters():
-                # g = torch.autograd.grad(y[0, t, d], p, retain_graph=True)[0]
-                # sg = sg + g.pow(2).sum().item()
-                data.append([t, sg])
-
-        data = torch.tensor(data)
-        ax.scatter(
-            data[:, 0], data[:, 1], s=1, label=label
-        )  # , color='gray', label='Input')
+        for n, p in [("input", x)] + list(model.named_parameters()):
+            print(f"Processing {model}.{n}")
+            data = []
+            for t in range(y.size(1)):
+                sg = 0
+                for d in torch.randperm(y.size(2))[:8]:
+                    sg += torch.autograd.grad(y[0, t, d], p, retain_graph=True)[0]
+                assert not sg.isinf().any()
+                assert not sg.isnan().any()
+                data.append([t, sg.sum().item()])
+
+            data = torch.tensor(data)
+            # cx, cy = data[:, 0], data[:, 1]
+            cy = data[:, 1].sort().values
+            cx = torch.linspace(0, 1, cy.size(0))
+            ax.plot(
+                cx, cy, label=label + "." + n, linewidth=thickness
+            )  # , color='gray', label='Input')
 
     # ax.legend(frameon=False, loc="top right")