Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 15 Feb 2024 22:10:17 +0000 (23:10 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 15 Feb 2024 22:10:17 +0000 (23:10 +0100)
fridge
main.py
maze.py
mygpt.py
stack.py
tasks.py

diff --git a/fridge b/fridge
index 82d2b17..143092c 100644 (file)
--- a/fridge
+++ b/fridge
@@ -316,3 +316,22 @@ class Calibrator:
             if isinstance(m, mygpt.Caterpillar):
                 
 
+
+######################################################################
+
+2024 Feb 13 22:53:52 (from mygpt.py)
+
+        ######################################################################
+        # Prepare the keys
+
+        k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
+
+        warnings.warn("rotating key barrel", RuntimeWarning)
+        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
+        t_barrel = torch.arange(t0, t1, device=k_star.device)
+        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
+        l_barrel = (
+            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
+        ) % k_star.size(0)
+        k_star = k_star[l_barrel, t_barrel]
+
diff --git a/main.py b/main.py
index d6845e8..6254807 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -11,6 +11,8 @@ import torch, torchvision
 from torch import nn
 from torch.nn import functional as F
 
+# torch.autograd.set_detect_anomaly(True) #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
 import ffutils
 import mygpt, tasks, problems
 
@@ -51,9 +53,11 @@ parser.add_argument("--force_cpu", type=str2bool, default=False)
 
 ########################################
 
-parser.add_argument("--nb_epochs", type=int, default=50)
+parser.add_argument("--nb_epochs", type=int, default=25)
+
+parser.add_argument("--physical_batch_size", type=int, default=None)
 
-parser.add_argument("--batch_size", type=int, default=None)
+parser.add_argument("--batch_size", type=int, default=25)
 
 parser.add_argument("--nb_train_samples", type=int, default=None)
 
@@ -89,7 +93,7 @@ parser.add_argument("--attention", type=str, default=None)
 
 parser.add_argument("--memex_proba", type=float, default=0)
 
-parser.add_argument("--memex_nb_epochs", type=float, default=1)
+parser.add_argument("--memex_nb_epochs", type=float, default=None)
 
 parser.add_argument("--dim_model", type=int, default=None)
 
@@ -238,97 +242,97 @@ else:
 default_task_args = {
     "addition": {
         "model": "352M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 250000,
         "nb_test_samples": 10000,
     },
     "byheart": {
         "model": "37M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 50000,
         "nb_test_samples": 10000,
     },
     "expr": {
         "model": "352M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 2500000,
         "nb_test_samples": 10000,
     },
     "grid": {
         "model": "37M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 250000,
         "nb_test_samples": 10000,
     },
     "qmlp": {
         "model": "37M",
-        "batch_size": 10,
+        "physical_batch_size": 10,
         "nb_train_samples": 100000,
         "nb_test_samples": 1000,
     },
     "guessop": {
         "model": "352M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 1000000,
         "nb_test_samples": 10000,
     },
     "learnop": {
         "model": "37M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 50000,
         "nb_test_samples": 10000,
     },
     "maze": {
         "model": "37M",
-        "batch_size": 5,
+        "physical_batch_size": 5,
         "nb_train_samples": 100000,
         "nb_test_samples": 10000,
     },
     "picoclvr": {
         "model": "37M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 250000,
         "nb_test_samples": 10000,
     },
     "rpl": {
         "model": "352M",
-        "batch_size": 5,
+        "physical_batch_size": 5,
         "nb_train_samples": 2500000,
         "nb_test_samples": 10000,
     },
     "snake": {
         "model": "37M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 250000,
         "nb_test_samples": 10000,
     },
     "stack": {
         "model": "37M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 100000,
         "nb_test_samples": 1000,
     },
     "twotargets": {
         "model": "37M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 50000,
         "nb_test_samples": 10000,
     },
     "memory": {
         "model": "37M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 25000,
         "nb_test_samples": 10000,
     },
     "mixing": {
         "model": "37M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 250000,
         "nb_test_samples": 10000,
     },
     "mnist": {
         "model": "37M",
-        "batch_size": 10,
+        "physical_batch_size": 5,
         "nb_train_samples": 60000,
         "nb_test_samples": 10000,
     },
@@ -526,6 +530,90 @@ def get_lr(n_epoch, it):
 ######################################################################
 
 
+def add_memex_v2(batches, memex_proba, marker_token):
+    for input in batches:
+        if torch.rand(1).item() < memex_proba:
+            t = (
+                torch.arange(1 + 2 * input.size(1), device=input.device)[None, :]
+                .expand(input.size(0), -1)
+                .clone()
+            )
+
+            u0 = torch.randint(input.size(1), (input.size(0), 1), device=input.device)
+            caterpillar_length = args.nb_lines // args.caterpillar_height
+            u1 = (
+                u0
+                + torch.randint(
+                    caterpillar_length, (input.size(0), 1), device=input.device
+                )
+                + 1
+            )
+
+            m0 = (t < u0).long()
+            m1 = (t >= u1).long() * (t < u1 + input.size(1)).long()
+
+            t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1
+            m = (t < 0).long()
+            n = torch.arange(input.size(0), device=input.device)[:, None].expand(
+                -1, t.size(1)
+            )
+
+            new_input = input[n, t.clamp(min=0)]
+            new_input = (1 - m) * new_input + m * (marker_token)
+
+            yield new_input
+
+        yield input
+
+
+def add_memex_v3(batches, memex_proba, marker_token):
+    for input in batches:
+        if torch.rand(1).item() < memex_proba:
+            t = (
+                torch.arange(2 * input.size(1), device=input.device)[None, :]
+                .expand(input.size(0), -1)
+                .clone()
+            )
+
+            u = torch.rand(t.size(), device=t.device)
+            u[:, : input.size(1)] = 1.0
+            memex_v3_proba_fragment = 1 / 20
+            u = (u < memex_v3_proba_fragment).long()
+            v = u * torch.randint(input.size(1), u.size())
+            u[:, input.size(1) + 1 :] = v[:, input.size(1) + 1 :] - u[
+                :, : input.size(1) - 1
+            ] * input.size(1)
+            u = u.cumsum().clamp(min=0)
+
+            u0 = torch.randint(input.size(1), (input.size(0), 1), device=input.device)
+            caterpillar_length = args.nb_lines // args.caterpillar_height
+            u1 = (
+                u0
+                + torch.randint(
+                    caterpillar_length, (input.size(0), 1), device=input.device
+                )
+                + 1
+            )
+
+            m0 = (t < u0).long()
+            m1 = (t >= u1).long() * (t < u1 + input.size(1)).long()
+
+            t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1
+            m = (t < 0).long()
+            n = torch.arange(input.size(0), device=input.device)[:, None].expand(
+                -1, t.size(1)
+            )
+
+            new_input = input[n, t.clamp(min=0)]
+            new_input = (1 - m) * new_input + m * (marker_token)
+
+            yield new_input
+
+        yield input
+
+
+######################################################################
+
 assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
 
 
@@ -554,7 +642,7 @@ if args.task == "byheart":
         problem=problems.ProblemByHeart(),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device_data,
     )
@@ -565,7 +653,7 @@ elif args.task == "learnop":
         problem=problems.ProblemLearnOperator(),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device_data,
     )
@@ -576,7 +664,7 @@ elif args.task == "guessop":
         problem=problems.ProblemGuessOperator(),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device_data,
     )
@@ -587,7 +675,7 @@ elif args.task == "twotargets":
         problem=problems.ProblemTwoTargets(),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device_data,
     )
@@ -597,7 +685,7 @@ elif args.task == "memory":
         problem=problems.ProblemMemory(len_total=args.memory_len_total),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device_data,
     )
@@ -609,7 +697,7 @@ elif args.task == "mixing":
         ),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device_data,
     )
@@ -619,7 +707,7 @@ elif args.task == "addition":
         problem=problems.ProblemAddition(),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device_data,
     )
@@ -628,7 +716,7 @@ elif args.task == "picoclvr":
     task = tasks.PicoCLVR(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         height=args.picoclvr_height,
         width=args.picoclvr_width,
         nb_colors=args.picoclvr_nb_colors,
@@ -642,7 +730,7 @@ elif args.task == "mnist":
     task = tasks.MNIST(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         device=device_data,
     )
 
@@ -650,7 +738,7 @@ elif args.task == "maze":
     task = tasks.Maze(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         height=args.maze_height,
         width=args.maze_width,
         nb_walls=args.maze_nb_walls,
@@ -661,7 +749,7 @@ elif args.task == "snake":
     task = tasks.Snake(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         height=args.snake_height,
         width=args.snake_width,
         nb_colors=args.snake_nb_colors,
@@ -674,7 +762,7 @@ elif args.task == "stack":
     task = tasks.Stack(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         nb_steps=args.stack_nb_steps,
         nb_stacks=args.stack_nb_stacks,
@@ -691,7 +779,7 @@ elif args.task == "expr":
         sequence_length=args.expr_sequence_length,
         operand_max=args.expr_operand_max,
         result_max=args.expr_result_max,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         device=device_data,
     )
 
@@ -699,7 +787,7 @@ elif args.task == "rpl":
     task = tasks.RPL(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         nb_starting_values=args.rpl_nb_starting_values,
         max_input=args.rpl_max_input,
         prog_len=args.rpl_prog_len,
@@ -713,7 +801,7 @@ elif args.task == "grid":
     task = tasks.Grid(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         size=args.grid_size,
         nb_shapes=args.grid_nb_shapes,
         nb_colors=args.grid_nb_colors,
@@ -725,7 +813,7 @@ elif args.task == "qmlp":
     task = tasks.QMLP(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         result_dir=args.result_dir,
         logger=log_string,
         device=device_data,
@@ -904,60 +992,63 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
 
     nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0
 
-    def add_memex(batches, memex_proba):
-        for input in batches:
-            if torch.rand(1).item() < memex_proba:
-                sep = torch.full(
-                    (input.size(0), 1), vocabulary_size - 1, device=input.device
-                )
+    memex_proba = (
+        args.memex_proba
+        if args.memex_nb_epochs is None or n_epoch < args.memex_nb_epochs
+        else 0.0
+    )
 
-                yield torch.cat(
-                    [
-                        input,
-                        sep,
-                        input,
-                    ],
-                    dim=1,
-                )
-            yield input
+    log_string(f"memex_proba {memex_proba}")
 
-    train_batches = add_memex(
-        task.batches(split="train"),
-        args.memex_proba if n_epoch < args.memex_nb_epochs else 0.0,
+    train_batches = add_memex_v2(
+        batches=task.batches(split="train"),
+        memex_proba=memex_proba,
+        marker_token=vocabulary_size - 1,
     )
 
-    for input in train_batches:
-        model.reset_inner_loss()
-        input = input.to(device)
+    def add_none(it):
+        for x in it:
+            yield x
+        yield None
 
-        output = model(mygpt.BracketedSequence(input)).x
-        loss = F.cross_entropy(output.transpose(1, 2), input)
-        inner_loss = model.get_inner_loss()
+    nb_acc_samples = 0
 
-        acc_train_loss += loss.item() * input.size(0)
-        acc_train_inner_loss += inner_loss.item() * input.size(0)
+    for input in add_none(train_batches):
+        if input is not None:
+            model.reset_inner_loss()
+            input = input.to(device)
 
-        nb_train_samples += input.size(0)
-        nb_samples_seen += input.size(0)
+            output = model(mygpt.BracketedSequence(input)).x
+            loss = F.cross_entropy(output.transpose(1, 2), input)
+            inner_loss = model.get_inner_loss()
 
-        total_loss = loss + (
-            args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0
-        )
+            acc_train_loss += loss.item() * input.size(0)
+            acc_train_inner_loss += inner_loss.item() * input.size(0)
+
+            nb_train_samples += input.size(0)
+            nb_samples_seen += input.size(0)
 
-        it += 1
-        lr = get_lr(n_epoch, it)
-        for param_group in optimizer.param_groups:
-            param_group["lr"] = lr
+            total_loss = loss + (
+                args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0
+            )
 
-        # log_string(f"learning_rate {lr}")
+            it += 1
+            lr = get_lr(n_epoch, it)
+            for param_group in optimizer.param_groups:
+                param_group["lr"] = lr
 
-        optimizer.zero_grad()
-        total_loss.backward()
-        optimizer.step()
+                # log_string(f"learning_rate {lr}")
 
-        grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt()
+            total_loss.backward()
+            nb_acc_samples += input.size(0)
 
-        loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
+        if (input is None and nb_acc_samples > 0) or nb_acc_samples == args.batch_size:
+            assert nb_acc_samples <= args.batch_size
+            optimizer.step()
+            grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt()
+            loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
+            optimizer.zero_grad()
+            nb_acc_samples = 0
 
         n_batch += 1
 
diff --git a/maze.py b/maze.py
index 8ac9fce..4953d10 100755 (executable)
--- a/maze.py
+++ b/maze.py
@@ -231,9 +231,14 @@ def save_image(
             [0, 255, 0],  # start
             [127, 127, 255],  # goal
             [255, 0, 0],  # path
+            [128, 128, 128],  # error
         ]
     )
 
+    def safe_colors(x):
+        m = (x >= 0).long() * (x < colors.size(0) - 1).long()
+        return colors[x * m + (colors.size(0) - 1) * (1 - m)]
+
     mazes = mazes.cpu()
 
     c_mazes = (
@@ -256,7 +261,7 @@ def save_image(
     if predicted_paths is not None:
         predicted_paths = predicted_paths.cpu()
         c_predicted_paths = (
-            colors[predicted_paths.reshape(-1)]
+            safe_colors(predicted_paths.reshape(-1))
             .reshape(predicted_paths.size() + (-1,))
             .permute(0, 3, 1, 2)
         )
@@ -282,8 +287,6 @@ def save_image(
         -1, -1, imgs.size(3) + 2, 1 + imgs.size(1) * (1 + imgs.size(4))
     ).clone()
 
-    print(f"{img.size()=} {imgs.size()=}")
-
     for k in range(imgs.size(1)):
         img[
             :,
index c833012..12b3631 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -86,6 +86,18 @@ class CacheWrapper(nn.Module):
 ##############################
 
 
+class NaNChecker(nn.Module):
+    def __init__(self, name):
+        super().__init__()
+        self.name = name
+
+    def forward(self, bs):
+        x = bs.x if type(bs) is BracketedSequence else bs
+        assert not x.isnan().any(), f"${self.name} detected NaN"
+        assert not x.isinf().any(), f"${self.name} detected Inf"
+        return bs
+
+
 class WithResidual(nn.Module):
     def __init__(self, *f):
         super().__init__()
@@ -218,19 +230,9 @@ class DumbRec(nn.Module):
 
         self.w_qw = randw(nb_heads, dim_qk, dim_model)
         self.w_qr = randw(nb_heads, dim_qk, dim_model)
-        # self.w_k = randw(nb_heads, dim_qk, dim_model)
         self.w_v = randw(nb_heads, dim_v, dim_model)
         self.w_o = randw(dim_v * nb_heads, dim_model)
 
-    def reset_inner_loss(self):
-        self.acc_attention = 0
-        self.acc_nb = 0
-
-    def get_inner_loss(self):
-        warnings.warn("l2 regularization", RuntimeWarning)
-        return (self.acc_attention / self.acc_nb).pow(2).sum()
-        # return torch.tensor([0], device=self.w_qw.device)
-
     def forward(self, bs):
         x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
 
@@ -238,61 +240,33 @@ class DumbRec(nn.Module):
             self.rec_v = x_q.new_zeros(
                 x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
             )
-            # self.rec_k = x_q.new_zeros(
-            # x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
-            # )
             self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
 
-        ######################################################################
-        # Prepare the keys
-
-        k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
-
-        warnings.warn("rotating key barrel", RuntimeWarning)
-        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
-        t_barrel = torch.arange(t0, t1, device=k_star.device)
-        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
-        l_barrel = (
-            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
-        ) % k_star.size(0)
-        k_star = k_star[l_barrel, t_barrel]
-
         ######################################################################
         # Compute the recurrent state
 
         qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
 
         v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
-        # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
 
-        aw = torch.einsum(
-            "nhtd,ltd->nhlt",
-            qw,
-            k_star,
-        ) / math.sqrt(self.w_qw.size(1))
+        aw = torch.einsum("nhtd,ld->nhlt", qw, self.k_star) / math.sqrt(
+            self.w_qw.size(1)
+        )
 
         aw = aw.softmax(dim=2)  # nhlt
 
-        if self.train:
-            self.acc_attention += aw.sum(dim=(0, 1, 3))
-            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
-
         aw = F.dropout(aw, self.attention_dropout, self.training)
 
         A = 1 - aw.sum(dim=1)  # nlt
 
         V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
-        # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
 
         if t0 == 0:
             V0 = None
-            # K0 = None
         else:
             V0 = self.rec_v[:, :, t0 - 1]
-            # K0 = self.rec_k[:, :, t0 - 1]
 
         self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
-        # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
 
         ######################################################################
         # compute the readout
@@ -302,7 +276,6 @@ class DumbRec(nn.Module):
         ar = torch.einsum(
             "nhtd,ld->nhlt",
             qr,
-            # self.rec_k[:, :, t0:t1],
             self.k_star,
         ) / math.sqrt(self.w_qr.size(1))
 
@@ -358,9 +331,9 @@ class KVRec(nn.Module):
         self.acc_nb = 0
 
     def get_inner_loss(self):
-        warnings.warn("l2 regularization", RuntimeWarning)
-        return (self.acc_attention / self.acc_nb).pow(2).sum()
-        return torch.tensor([0], device=self.w_qw.device)
+        warnings.warn("l2 regularization", RuntimeWarning)
+        return (self.acc_attention / self.acc_nb).pow(2).sum()
+        return torch.tensor([0], device=self.w_qw.device)
         # warnings.warn("side regularization", RuntimeWarning)
         # return (
         # (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
@@ -384,12 +357,12 @@ class KVRec(nn.Module):
 
         k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
 
-        warnings.warn("rotating key barrel", RuntimeWarning)
+        warnings.warn("rotating key barrel", RuntimeWarning)
         k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
         t_barrel = torch.arange(t0, t1, device=k_star.device)
         t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
         l_barrel = (
-            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
+            torch.arange(k_star.size(0), device=k_star.device)[:, None]  # + t_barrel
         ) % k_star.size(0)
         k_star = k_star[l_barrel, t_barrel]
 
@@ -781,6 +754,8 @@ class MyGPT(nn.Module):
     ):
         super().__init__()
 
+        self.vocabulary_size = vocabulary_size
+
         assert attention_layer in {
             "mha",
             "dumbrec",
index 543f04e..69a696d 100755 (executable)
--- a/stack.py
+++ b/stack.py
@@ -25,23 +25,34 @@ def generate_sequences(
     )
 
     for t in range(nb_steps):
-        op = torch.randint(2, (nb,))
-        st = torch.randint(nb_stacks, (nb,))
-        op = op * (stack_counts[k, st] > 0)
-        if values is None:
+        op = torch.randint(2, (nb,))  # what operation (push/pop)
+        st = torch.randint(nb_stacks, (nb,))  # on what stack
+        op = op * (stack_counts[k, st] > 0)  # can only push is stack is empty
+
+        if values is None:  # we can use all the values
             val_push = torch.randint(10**nb_digits, (nb,))
-        else:
+        else:  # values are constrained (e.g. to have train/test values disjoint)
             val_push = values[torch.randint(values.size(0), (nb,))]
-        val_pop = stack[
+
+        val_pop = stack[  # if we were popping, what value would that be?
             k,
             st,
-            (stack_counts[k, st] - 1).clamp(min=0),
+            (stack_counts[k, st] - 1).clamp(min=0),  # deal with empty stack
         ]
+
+        # we always push the value, but it will be lost if we pop
+        # since we will move the count down
         stack[k, st, stack_counts[k, st]] = val_push
         recorded_stack_counts[:, (1 + nb_digits) * t] = stack_counts[k, st]
+
+        # we increase the stack count only when we actually push
         stack_counts[k[op == 0], st[op == 0]] += 1
         stack_counts[k[op == 1], st[op == 1]] -= 1
+
+        # add the operation number to the sequence, that incude the stack number
         result[:, (1 + nb_digits) * t] = st * 2 + op
+
+        # add the digits to the sequence
         for d in range(nb_digits):
             result[:, (1 + nb_digits) * t + 1 + d] = (
                 (op * val_pop + (1 - op) * val_push) // (10**d)
@@ -57,29 +68,49 @@ def remove_popped_values(seq, nb_stacks, nb_digits):
         seq[:, k:] = -m[:, :-k] + (1 - m[:, :-k]) * seq[:, k:]
 
 
-def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None):
-    assert seq.size(0) % (1 + nb_digits) == 0
-    s = ""
-    for t in range(seq.size(0) // (1 + nb_digits)):
-        n_op = seq[(1 + nb_digits) * t]
-        if t > 0:
-            s += " "
-        if recorded_stack_counts is not None:
-            s += f"[{recorded_stack_counts[(1 + nb_digits)*t]}] "
-        s += f"POP" if n_op % 2 == 1 else f"PSH"
-        if nb_stacks > 1:
-            s += f"_{n_op//2}"
-        for d in range(nb_digits):
-            if seq[(1 + nb_digits) * t + 1 + d] == -1:
-                s += " ?"
-            else:
-                s += f" {seq[(1 + nb_digits) * t + 1 + d] - 2 * nb_stacks:1d}"
-    return s
+def seq_to_str(seq, nb_stacks, nb_digits):
+    def n_to_str(n):
+        if n < 0:
+            return "?"
+        elif n < 2 * nb_stacks:
+            s = f"POP" if n % 2 == 1 else f"PSH"
+            if nb_stacks > 1:
+                s += f"_{n//2}"
+                return s
+        elif n < 2 * nb_stacks + 10:
+            return f"{n - 2 * nb_stacks}"
+        else:
+            return "#"
+
+    return " ".join([n_to_str(x.item()) for x in seq])
 
 
 ######################################################################
 
 if __name__ == "__main__":
+    seq, recorded_stack_counts = generate_sequences(
+        nb=3,
+        nb_steps=6,
+        nb_stacks=3,
+        nb_digits=3,
+    )
+
+    sep = torch.full((seq.size(0), 1), seq.max() + 1)
+
+    seq = torch.cat([seq, sep, seq], dim=1)
+
+    for n in range(min(10, seq.size(0))):
+        print(seq_to_str(seq[n], nb_stacks=3, nb_digits=3))
+
+    remove_popped_values(seq, 3, 3)
+
+    print()
+
+    for n in range(min(10, seq.size(0))):
+        print(seq_to_str(seq[n], nb_stacks=3, nb_digits=3))
+
+    exit(0)
+
     nb, nb_steps, nb_stacks, nb_digits = 150000, 20, 2, 1
     seq, recorded_stack_counts = generate_sequences(
         nb=nb,
@@ -101,6 +132,8 @@ if __name__ == "__main__":
 
     print("-- PREPARED FOR TEST -----------------")
 
+    print("SANITY", seq.size())
+
     remove_popped_values(seq, nb_stacks, nb_digits)
 
     for n in range(min(10, seq.size(0))):
index 727b196..218ff36 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -250,7 +250,13 @@ class PicoCLVR(Task):
 
     # Make a list of strings from a tensor
     def detensorize(self, x):
-        return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
+        def id2token(t):
+            try:
+                return self.id2token[t.item()]
+            except KeyError:
+                return "?"
+
+        return [" ".join([id2token(t) for t in r]) for r in x]
 
     # trim all the tensors in the tuple z to remove as much token from
     # left and right in the first tensor. If z is a tuple, all its
@@ -888,7 +894,10 @@ class Stack(Task):
         def compute_nb_correct(input):
             result = input.clone()
             stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
+
             ar_mask = (result != input).long()
+            result *= 1 - ar_mask
+
             masked_inplace_autoregression(
                 model,
                 self.batch_size,
@@ -923,10 +932,12 @@ class Stack(Task):
         stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
         ar_mask = (result != input).long()
 
-        # for n in range(result.size(0)):
-        # logger(
-        # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
-        # )
+        for n in range(result.size(0)):
+            logger(
+                f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
+            )
+
+        result *= 1 - ar_mask
 
         masked_inplace_autoregression(
             model,
@@ -1448,7 +1459,13 @@ class Grid(Task):
 
     # Make a list of strings from a tensor
     def tensor2str(self, x):
-        return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
+        def id2token(t):
+            try:
+                return self.id2token[t.item()]
+            except KeyError:
+                return "?"
+
+        return [" ".join([id2token(t) for t in r]) for r in x]
 
     # trim all the tensors in the tuple z to remove as much token from
     # left and right in the first tensor. If z is a tuple, all its