Update.
[mygptrnn.git] / mygpt.py
index c833012..12b3631 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -86,6 +86,18 @@ class CacheWrapper(nn.Module):
 ##############################
 
 
+class NaNChecker(nn.Module):
+    def __init__(self, name):
+        super().__init__()
+        self.name = name
+
+    def forward(self, bs):
+        x = bs.x if type(bs) is BracketedSequence else bs
+        assert not x.isnan().any(), f"${self.name} detected NaN"
+        assert not x.isinf().any(), f"${self.name} detected Inf"
+        return bs
+
+
 class WithResidual(nn.Module):
     def __init__(self, *f):
         super().__init__()
@@ -218,19 +230,9 @@ class DumbRec(nn.Module):
 
         self.w_qw = randw(nb_heads, dim_qk, dim_model)
         self.w_qr = randw(nb_heads, dim_qk, dim_model)
-        # self.w_k = randw(nb_heads, dim_qk, dim_model)
         self.w_v = randw(nb_heads, dim_v, dim_model)
         self.w_o = randw(dim_v * nb_heads, dim_model)
 
-    def reset_inner_loss(self):
-        self.acc_attention = 0
-        self.acc_nb = 0
-
-    def get_inner_loss(self):
-        warnings.warn("l2 regularization", RuntimeWarning)
-        return (self.acc_attention / self.acc_nb).pow(2).sum()
-        # return torch.tensor([0], device=self.w_qw.device)
-
     def forward(self, bs):
         x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
 
@@ -238,61 +240,33 @@ class DumbRec(nn.Module):
             self.rec_v = x_q.new_zeros(
                 x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
             )
-            # self.rec_k = x_q.new_zeros(
-            # x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
-            # )
             self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
 
-        ######################################################################
-        # Prepare the keys
-
-        k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
-
-        warnings.warn("rotating key barrel", RuntimeWarning)
-        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
-        t_barrel = torch.arange(t0, t1, device=k_star.device)
-        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
-        l_barrel = (
-            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
-        ) % k_star.size(0)
-        k_star = k_star[l_barrel, t_barrel]
-
         ######################################################################
         # Compute the recurrent state
 
         qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
 
         v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
-        # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
 
-        aw = torch.einsum(
-            "nhtd,ltd->nhlt",
-            qw,
-            k_star,
-        ) / math.sqrt(self.w_qw.size(1))
+        aw = torch.einsum("nhtd,ld->nhlt", qw, self.k_star) / math.sqrt(
+            self.w_qw.size(1)
+        )
 
         aw = aw.softmax(dim=2)  # nhlt
 
-        if self.train:
-            self.acc_attention += aw.sum(dim=(0, 1, 3))
-            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
-
         aw = F.dropout(aw, self.attention_dropout, self.training)
 
         A = 1 - aw.sum(dim=1)  # nlt
 
         V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
-        # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
 
         if t0 == 0:
             V0 = None
-            # K0 = None
         else:
             V0 = self.rec_v[:, :, t0 - 1]
-            # K0 = self.rec_k[:, :, t0 - 1]
 
         self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
-        # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
 
         ######################################################################
         # compute the readout
@@ -302,7 +276,6 @@ class DumbRec(nn.Module):
         ar = torch.einsum(
             "nhtd,ld->nhlt",
             qr,
-            # self.rec_k[:, :, t0:t1],
             self.k_star,
         ) / math.sqrt(self.w_qr.size(1))
 
@@ -358,9 +331,9 @@ class KVRec(nn.Module):
         self.acc_nb = 0
 
     def get_inner_loss(self):
-        warnings.warn("l2 regularization", RuntimeWarning)
-        return (self.acc_attention / self.acc_nb).pow(2).sum()
-        return torch.tensor([0], device=self.w_qw.device)
+        warnings.warn("l2 regularization", RuntimeWarning)
+        return (self.acc_attention / self.acc_nb).pow(2).sum()
+        return torch.tensor([0], device=self.w_qw.device)
         # warnings.warn("side regularization", RuntimeWarning)
         # return (
         # (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
@@ -384,12 +357,12 @@ class KVRec(nn.Module):
 
         k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
 
-        warnings.warn("rotating key barrel", RuntimeWarning)
+        warnings.warn("rotating key barrel", RuntimeWarning)
         k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
         t_barrel = torch.arange(t0, t1, device=k_star.device)
         t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
         l_barrel = (
-            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
+            torch.arange(k_star.size(0), device=k_star.device)[:, None]  # + t_barrel
         ) % k_star.size(0)
         k_star = k_star[l_barrel, t_barrel]
 
@@ -781,6 +754,8 @@ class MyGPT(nn.Module):
     ):
         super().__init__()
 
+        self.vocabulary_size = vocabulary_size
+
         assert attention_layer in {
             "mha",
             "dumbrec",