Initial commit
author    François Fleuret <francois@fleuret.org>
          Sat, 6 Jan 2024 09:56:07 +0000 (10:56 +0100)
committer François Fleuret <francois@fleuret.org>
          Sat, 6 Jan 2024 09:56:07 +0000 (10:56 +0100)
17 files changed:
expr.py [new file with mode: 0755]
ffutils.py [new file with mode: 0755]
graph.py [new file with mode: 0755]
grid.py [new file with mode: 0755]
main.py [new file with mode: 0755]
maze.py [new file with mode: 0755]
memload.py [new file with mode: 0755]
mygpt.py [new file with mode: 0755]
picoclvr.py [new file with mode: 0755]
problems.py [new file with mode: 0755]
pscan.py [new file with mode: 0755]
qmlp.py [new file with mode: 0755]
rpl.py [new file with mode: 0755]
snake.py [new file with mode: 0755]
stack.py [new file with mode: 0755]
tasks.py [new file with mode: 0755]
world.py [new file with mode: 0755]

diff --git a/expr.py b/expr.py
new file mode 100755 (executable)
index 0000000..685efd3
--- /dev/null
+++ b/expr.py
@@ -0,0 +1,105 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math, re
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+
+def random_var(nb_variables=None, variables=None):
+    if variables is None:
+        return chr(ord("A") + torch.randint(nb_variables, (1,)).item())
+    else:
+        l = list(variables)
+        return l[torch.randint(len(l), (1,)).item()]
+
+
+def random_expr(variables, operand_max, budget):
+    if budget <= 5:
+        op = torch.randint(2, (1,)).item()
+        if op == 0 and len(variables) > 0:
+            return random_var(variables=variables)
+        else:
+            return str(torch.randint(operand_max + 1, (1,)).item())
+    else:
+        op = torch.randint(3, (1,)).item()
+        if op == 0:
+            e = random_expr(variables, operand_max, budget - 2)
+            if ("+" in e or "-" in e or "*" in e) and (e[0] != "(" or e[-1] != ")"):
+                return "(" + e + ")"
+            else:
+                return e
+        else:
+            b = 2 + torch.randint(budget - 5, (1,)).item()
+            e1 = random_expr(variables, operand_max, b)
+            e2 = random_expr(variables, operand_max, budget - b - 1)
+            if op == 1:
+                return e1 + "+" + e2
+            elif op == 2:
+                return e1 + "*" + e2
+
+
+def generate_program(nb_variables, operand_max, length):
+    s = ""
+    variables = set()
+
+    while len(s) < length:
+        v = random_var(nb_variables=nb_variables)
+        s += v + "=" + random_expr(variables, operand_max, budget=20) + ";"
+        variables.add(v)
+
+    return s, variables
+
+
+def generate_sequences(nb, nb_variables=5, length=20, operand_max=9, result_max=99):
+    assert nb_variables <= 26
+    sequences = []
+
+    for n in range(nb):
+        # We use `length` itself half of the time, and draw the length
+        # uniformly between 1 and length otherwise. The actual sequence
+        # length can be slightly greater.
+
+        l = min(length, 1 + torch.randint(length * 2, (1,)).item())
+        result = None
+        while result is None or max(result.values()) > result_max:
+            p, v = generate_program(nb_variables, operand_max, l)
+            v = ", ".join(['"' + v + '": ' + v for v in v])
+            ldict = {}
+            exec(p + "result={" + v + "}", globals(), ldict)
+            result = ldict["result"]
+
+        k = list(result.keys())
+        k.sort()
+        sequences.append(p + " " + "".join([v + ":" + str(result[v]) + ";" for v in k]))
+
+    return sequences
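+
+# A generated sequence looks like (one hypothetical draw):
+#   "C=5;A=C+7;B=A*2; A:12;B:24;C:5;"
+# i.e. the program, a space, then the variables sorted by name with
+# their final values.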
+
+
+def extract_results(seq):
+    f = lambda a: (a[0], -1 if a[1] == "" else int(a[1]))
+    results = [
+        dict([f(tuple(x.split(":"))) for x in re.findall("[A-Z]:[0-9]*", s)])
+        for s in seq
+    ]
+    return results
+
+
+if __name__ == "__main__":
+    import time
+
+    start_time = time.perf_counter()
+    sequences = generate_sequences(1000, length=40)
+    end_time = time.perf_counter()
+    for s in sequences[:10]:
+        print(s)
+    print(f"{len(sequences) / (end_time - start_time):.02f} samples per second")
+
+    print(extract_results(sequences[:10]))
diff --git a/ffutils.py b/ffutils.py
new file mode 100755 (executable)
index 0000000..23952e5
--- /dev/null
+++ b/ffutils.py
@@ -0,0 +1,108 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import sys, contextlib
+
+import torch
+from torch import Tensor
+
+######################################################################
+
+
+@contextlib.contextmanager
+def evaluation(*models):
+    with torch.inference_mode():
+        t = [(m, m.training) for m in models]
+        for m in models:
+            m.train(False)
+        yield
+        for m, u in t:
+            m.train(u)
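+
+# A minimal usage sketch (hypothetical m1, m2): the body runs under
+# torch.inference_mode() with train(False), and each model's original
+# training flag is restored on exit.
+#
+#   m1, m2 = torch.nn.Linear(2, 2), torch.nn.Dropout(0.5)
+#   with evaluation(m1, m2):
+#       y = m2(m1(torch.randn(4, 2)))
+#   assert m1.training and m2.training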
+
+
+######################################################################
+
+from torch.utils._python_dispatch import TorchDispatchMode
+
+
+def hasNaN(x):
+    if torch.is_tensor(x):
+        return x.numel() > 0 and x.isnan().max()
+    else:
+        try:
+            return any([hasNaN(y) for y in x])
+        except TypeError:
+            return False
+
+
+class NaNDetect(TorchDispatchMode):
+    def __torch_dispatch__(self, func, types, args, kwargs=None):
+        kwargs = kwargs or {}
+        res = func(*args, **kwargs)
+
+        if hasNaN(res):
+            raise RuntimeError(
+                f"Function {func}(*{args}, **{kwargs}) " "returned a NaN"
+            )
+        return res
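+
+# A minimal usage sketch: any operation dispatched while the mode is
+# active that returns a NaN raises a RuntimeError.
+#
+#   with NaNDetect():
+#       x = torch.zeros(3)
+#       print(x / x)  # 0/0 produces NaNs, so this raises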
+
+
+######################################################################
+
+
+def exception_hook(exc_type, exc_value, tb):
+    r"""Hacks the call stack message to show all the local variables
+    in case of relevant error, and prints tensors as shape, dtype and
+    device.
+
+    """
+
+    repr_orig = Tensor.__repr__
+    Tensor.__repr__ = lambda x: f"{x.size()}:{x.dtype}:{x.device}"
+
+    while tb:
+        print("--------------------------------------------------\n")
+        filename = tb.tb_frame.f_code.co_filename
+        name = tb.tb_frame.f_code.co_name
+        line_no = tb.tb_lineno
+        print(f'  File "{filename}", line {line_no}, in {name}')
+        print(open(filename, "r").readlines()[line_no - 1])
+
+        if exc_type in {RuntimeError, ValueError, IndexError, TypeError}:
+            for n, v in tb.tb_frame.f_locals.items():
+                print(f"  {n} -> {v}")
+
+        print()
+        tb = tb.tb_next
+
+    Tensor.__repr__ = repr_orig
+
+    print(f"{exc_type.__name__}: {exc_value}")
+
+
+def activate_tensorstack():
+    sys.excepthook = exception_hook
+
+
+######################################################################
+
+if __name__ == "__main__":
+    import torch
+
+    def dummy(a, b):
+        print(a @ b)
+
+    def blah(a, b):
+        c = b + b
+        dummy(a, c)
+
+    mmm = torch.randn(2, 3)
+    xxx = torch.randn(3)
+    # print(xxx@mmm)
+    blah(mmm, xxx)
+    blah(xxx, mmm)
diff --git a/graph.py b/graph.py
new file mode 100755 (executable)
index 0000000..07e376a
--- /dev/null
+++ b/graph.py
@@ -0,0 +1,185 @@
+#!/usr/bin/env python
+
+import math
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+import cairo
+
+
+######################################################################
+
+
+def save_attention_image(
+    # image to save
+    filename,
+    tokens_input,
+    tokens_output,
+    # list of 2d tensors T2xT1, T3xT2, ..., TkxTk-1
+    attention_matrices,
+    # do not draw links with a lesser attention
+    min_link_attention=0,
+    # draw only the strongest links necessary so that their summed
+    # attention is above min_total_attention
+    min_total_attention=None,
+    # draw only the top k links
+    k_top=None,
+    # the purely graphical settings
+    curved=True,
+    pixel_scale=8,
+    token_gap=15,
+    layer_gap=25,
+    y_eps=0.5,
+    padding=10,
+):
+    if k_top is not None:
+        am = []
+        for m in attention_matrices:
+            # keep, for each position, only its k_top strongest links:
+            # rank the entries by decreasing value and mask out the rest
+            ranks = m.sort(dim=-1, descending=True).indices.sort(dim=-1).indices
+            am.append(m * (ranks < k_top))
+        attention_matrices = am
+
+    if min_total_attention is not None:
+        am = []
+        for m in attention_matrices:
+            s = m.sort(dim=-1)
+            # mask, in rank space, of the strongest links whose summed
+            # attention reaches min_total_attention
+            k = 1 - (s.values.cumsum(-1) < 1 - min_total_attention).long()
+            # scatter the mask back to the original link positions
+            b = k.new_zeros(k.size()).scatter_(dim=-1, index=s.indices, src=k)
+            am.append(m * b)
+        attention_matrices = am
+
+    surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None)
+
+    ctx = cairo.Context(surface)
+    ctx.scale(pixel_scale, pixel_scale)
+
+    ctx.set_source_rgb(0.0, 0.0, 0.0)
+    ctx.set_font_size(4.0)
+    # ctx.select_font_face("Arial", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)
+
+    x, y = 0, 0
+
+    ctx.set_line_width(0.25)
+    for d in range(len(attention_matrices)):
+        at = attention_matrices[d].to("cpu")
+        ni = torch.arange(at.size(0))[:, None].expand_as(at)
+        nj = torch.arange(at.size(1))[None, :].expand_as(at)
+        at = at.flatten()
+        o = at.sort().indices
+        at = at[o]
+        ni = ni.flatten()[o]
+        nj = nj.flatten()[o]
+        for i, j, a in zip(ni, nj, at):
+            if a > 0 and a >= min_link_attention:
+                c = 1 - a.item()
+                ctx.set_source_rgb(c, c, c)
+                ax, ay = j * token_gap, y - y_eps
+                ctx.move_to(ax, ay)
+                dx, dy = i * token_gap, y - layer_gap + y_eps
+                if curved:
+                    bx, by = ax, ay - layer_gap * 0.5
+                    cx, cy = dx, dy + layer_gap * 0.5
+                    ctx.curve_to(bx, by, cx, cy, dx, dy)
+                else:
+                    ctx.line_to(dx, dy)
+                ctx.stroke()
+        y -= layer_gap
+
+    for d in range(0, len(attention_matrices) + 1):
+        nb = (
+            attention_matrices[0].size(-1)
+            if d == 0
+            else attention_matrices[d - 1].size(-2)
+        )
+        for n in range(nb):
+            xc, yc = n * token_gap, -d * layer_gap
+            ctx.set_source_rgb(1.0, 1.0, 1.0)
+            ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
+            ctx.fill()
+            ctx.set_source_rgb(0.0, 0.0, 0.0)
+            ctx.arc(xc, yc, token_gap / 20, 0, 2 * math.pi)
+            ctx.fill()
+
+    ctx.set_source_rgb(0.0, 0.0, 0.0)
+
+    for k, t in enumerate(tokens_input):
+        s = str(t)
+        (
+            x_bearing,
+            y_bearing,
+            width_t,
+            height_t,
+            x_advance,
+            y_advance,
+        ) = ctx.text_extents(s)
+        ctx.move_to(k * token_gap - width_t / 2, 2 * token_gap / 5)
+        ctx.show_text(s)
+
+    for k, t in enumerate(tokens_output):
+        s = str(t)
+        (
+            x_bearing,
+            y_bearing,
+            width_t,
+            height_t,
+            x_advance,
+            y_advance,
+        ) = ctx.text_extents(s)
+        ctx.move_to(
+            k * token_gap - width_t / 2,
+            -token_gap / 5 - len(attention_matrices) * layer_gap,
+        )
+        ctx.show_text(s)
+
+    x, y, width, height = surface.ink_extents()
+    x -= padding
+    y -= padding
+    width += 2 * padding
+    height += 2 * padding
+    pdf_surface = cairo.PDFSurface(filename, width, height)
+    ctx_pdf = cairo.Context(pdf_surface)
+    ctx_pdf.set_source_surface(surface, -x, -y)
+    ctx_pdf.paint()
+    pdf_surface.finish()
+
+
+######################################################################
+
+if __name__ == "__main__":
+    import mygpt
+
+    tokens_output = ["<wat>", "-", 3, 4, "<end>"]
+    tokens_input = [""] + tokens_output[:-1]
+
+    vocabulary_size = 3
+    x = torch.randint(vocabulary_size, (1, len(tokens_input)))
+
+    model = mygpt.MyGPT(
+        vocabulary_size=vocabulary_size,
+        dim_model=4,
+        dim_keys=2,
+        dim_hidden=2,
+        nb_heads=2,
+        nb_blocks=5,
+        dropout=0.1,
+        causal=True,
+    )
+
+    model.eval()
+    model.record_attention()
+
+    y1 = model(mygpt.BracketedSequence(x)).x
+
+    attention_matrices = [m[0, 0] for m in model.retrieve_attention()]
+
+    # attention_matrices = [torch.rand(*s) for s in [ (4,5),(3,4),(8,3),(5,8) ]]
+
+    save_attention_image(
+        "attention.pdf",
+        tokens_input,
+        tokens_output,
+        attention_matrices,
+        # k_top=2,
+        min_total_attention=0.9,
+    )
diff --git a/grid.py b/grid.py
new file mode 100755 (executable)
index 0000000..268f4ee
--- /dev/null
+++ b/grid.py
@@ -0,0 +1,236 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math
+import torch, torchvision
+import torch.nn.functional as F
+
+name_shapes = ["A", "B", "C", "D", "E", "F"]
+
+name_colors = ["red", "yellow", "blue", "green", "white", "purple"]
+
+######################################################################
+
+
+class GridFactory:
+    def __init__(
+        self,
+        size=6,
+        max_nb_items=4,
+        max_nb_transformations=3,
+        nb_questions=4,
+    ):
+        assert size % 2 == 0
+        self.size = size
+        self.max_nb_items = max_nb_items
+        self.max_nb_transformations = max_nb_transformations
+        self.nb_questions = nb_questions
+
+    def generate_scene(self):
+        nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2
+        col = torch.full((self.size * self.size,), -1)
+        shp = torch.full((self.size * self.size,), -1)
+        a = torch.randperm(len(name_colors) * len(name_shapes))[:nb_items]
+        col[:nb_items] = a % len(name_colors)
+        shp[:nb_items] = a // len(name_colors)
+        i = torch.randperm(self.size * self.size)
+        col = col[i]
+        shp = shp[i]
+        return col.reshape(self.size, self.size), shp.reshape(self.size, self.size)
+
+    def random_transformations(self, scene):
+        col, shp = scene
+
+        descriptions = []
+        nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item()
+        transformations = torch.randint(5, (nb_transformations,))
+
+        for t in transformations:
+            if t == 0:
+                col, shp = col.flip(0), shp.flip(0)
+                descriptions += ["<chg> vertical flip"]
+            elif t == 1:
+                col, shp = col.flip(1), shp.flip(1)
+                descriptions += ["<chg> horizontal flip"]
+            elif t == 2:
+                col, shp = col.flip(0).t(), shp.flip(0).t()
+                descriptions += ["<chg> rotate 90 degrees"]
+            elif t == 3:
+                col, shp = col.flip(0).flip(1), shp.flip(0).flip(1)
+                descriptions += ["<chg> rotate 180 degrees"]
+            elif t == 4:
+                col, shp = col.flip(1).t(), shp.flip(1).t()
+                descriptions += ["<chg> rotate 270 degrees"]
+
+            col, shp = col.contiguous(), shp.contiguous()
+
+        return (col, shp), descriptions
+
+    def print_scene(self, scene):
+        col, shp = scene
+
+        # for i in range(self.size):
+        # for j in range(self.size):
+        # if col[i,j] >= 0:
+        # print(f"at ({i},{j}) {name_colors[col[i,j]]} {name_shapes[shp[i,j]]}")
+
+        for i in range(self.size):
+            for j in range(self.size):
+                if col[i, j] >= 0:
+                    print(f"{name_colors[col[i,j]][0]}{name_shapes[shp[i,j]]}", end="")
+                elif j == 0:
+                    print(" +", end="")
+                else:
+                    print("-+", end="")
+                if j < self.size - 1:
+                    print("--", end="")
+                else:
+                    print("")
+            if i < self.size - 1:
+                for j in range(self.size - 1):
+                    print(" |  ", end="")
+                print(" |")
+
+    def grid_positions(self, scene):
+        col, shp = scene
+
+        properties = []
+
+        for i in range(self.size):
+            for j in range(self.size):
+                if col[i, j] >= 0:
+                    n = f"{name_colors[col[i,j]]} {name_shapes[shp[i,j]]}"
+                    properties += [f"a {n} at {i} {j}"]
+
+        return properties
+
+    def all_properties(self, scene):
+        col, shp = scene
+
+        properties = []
+
+        for i1 in range(self.size):
+            for j1 in range(self.size):
+                if col[i1, j1] >= 0:
+                    n1 = f"{name_colors[col[i1,j1]]} {name_shapes[shp[i1,j1]]}"
+                    properties += [f"there is a {n1}"]
+                    if i1 < self.size // 2:
+                        properties += [f"a {n1} is in the top half"]
+                    if i1 >= self.size // 2:
+                        properties += [f"a {n1} is in the bottom half"]
+                    if j1 < self.size // 2:
+                        properties += [f"a {n1} is in the left half"]
+                    if j1 >= self.size // 2:
+                        properties += [f"a {n1} is in the right half"]
+                    for i2 in range(self.size):
+                        for j2 in range(self.size):
+                            if col[i2, j2] >= 0:
+                                n2 = f"{name_colors[col[i2,j2]]} {name_shapes[shp[i2,j2]]}"
+                                if i1 > i2:
+                                    properties += [f"a {n1} is below a {n2}"]
+                                if i1 < i2:
+                                    properties += [f"a {n1} is above a {n2}"]
+                                if j1 > j2:
+                                    properties += [f"a {n1} is right of a {n2}"]
+                                if j1 < j2:
+                                    properties += [f"a {n1} is left of a {n2}"]
+                                if abs(i1 - i2) + abs(j1 - j2) == 1:
+                                    properties += [f"a {n1} is next to a {n2}"]
+
+        return properties
+
+    def generate_scene_and_questions(self):
+        while True:
+            while True:
+                start_scene = self.generate_scene()
+                scene, transformations = self.random_transformations(start_scene)
+                true = self.all_properties(scene)
+                if len(true) >= self.nb_questions:
+                    break
+
+            for a in range(10):
+                col, shp = scene
+                col, shp = col.view(-1), shp.view(-1)
+                p = torch.randperm(col.size(0))
+                col, shp = col[p], shp[p]
+                other_scene = (
+                    col.view(self.size, self.size),
+                    shp.view(self.size, self.size),
+                )
+
+                false = self.all_properties(other_scene)
+
+                # We sometimes add properties from a totally different
+                # scene to get negative "there is a xxx xxx"
+                # properties
+                if torch.rand(1).item() < 0.2:
+                    other_scene = self.generate_scene()
+                    false += self.all_properties(other_scene)
+
+                false = list(set(false) - set(true))
+                if len(false) >= self.nb_questions:
+                    break
+
+            # the ten attempts above may all fail to produce enough
+            # false properties, in which case we resample the scene
+            if len(false) >= self.nb_questions:
+                break
+
+        true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions]]
+        false = [false[k] for k in torch.randperm(len(false))[: self.nb_questions]]
+        true = ["<prop> " + q + " <ans> true" for q in true]
+        false = ["<prop> " + q + " <ans> false" for q in false]
+
+        union = true + false
+        questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]]
+
+        result = " ".join(
+            ["<obj> " + x for x in self.grid_positions(start_scene)]
+            + transformations
+            + questions
+        )
+
+        return start_scene, scene, result
+
+    def generate_samples(self, nb, progress_bar=None):
+        result = []
+
+        r = range(nb)
+        if progress_bar is not None:
+            r = progress_bar(r)
+
+        for _ in r:
+            result.append(self.generate_scene_and_questions()[2])
+
+        return result
+
+
+######################################################################
+
+if __name__ == "__main__":
+    import time
+
+    grid_factory = GridFactory()
+
+    # start_time = time.perf_counter()
+    # samples = grid_factory.generate_samples(10000)
+    # end_time = time.perf_counter()
+    # print(f"{len(samples) / (end_time - start_time):.02f} samples per second")
+
+    start_scene, scene, questions = grid_factory.generate_scene_and_questions()
+    print()
+    print("-- Original scene -----------------------------")
+    print()
+    grid_factory.print_scene(start_scene)
+    print()
+    print("-- Transformed scene --------------------------")
+    print()
+    grid_factory.print_scene(scene)
+    print()
+    print("-- Sequence -----------------------------------")
+    print()
+    print(questions)
+
+######################################################################
diff --git a/main.py b/main.py
new file mode 100755 (executable)
index 0000000..df46652
--- /dev/null
+++ b/main.py
@@ -0,0 +1,912 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math, sys, argparse, time, tqdm, os, datetime, warnings
+
+import torch, torchvision
+from torch import nn
+from torch.nn import functional as F
+
+import ffutils
+import mygpt, tasks, problems
+
+######################################################################
+
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+    torch.backends.cuda.matmul.allow_tf32 = True
+else:
+    device = torch.device("cpu")
+
+######################################################################
+
+parser = argparse.ArgumentParser(
+    description="An implementation of GPT with cache.",
+    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+)
+
+parser.add_argument(
+    "--task",
+    type=str,
+    default="twotargets",
+    help="byheart, learnop, guessop, mixing, memory, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp",
+)
+
+parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
+
+parser.add_argument("--result_dir", type=str, default=None)
+
+parser.add_argument("--seed", type=int, default=0)
+
+parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
+
+########################################
+
+parser.add_argument("--nb_epochs", type=int, default=50)
+
+parser.add_argument("--batch_size", type=int, default=None)
+
+parser.add_argument("--nb_train_samples", type=int, default=None)
+
+parser.add_argument("--nb_test_samples", type=int, default=None)
+
+parser.add_argument("--optim", type=str, default="adam")
+
+########################################
+
+parser.add_argument("--nb_warmup_iter", type=int, default=100)
+
+parser.add_argument("--nb_decay_iter", type=int, default=5000)
+
+parser.add_argument("--learning_rate", type=float, default=6e-4)
+
+parser.add_argument("--min_learning_rate", type=float, default=6e-5)
+
+########################################
+
+parser.add_argument("--model", type=str, default=None)
+
+parser.add_argument("--attention", type=str, default=None)
+
+parser.add_argument("--dim_model", type=int, default=None)
+
+parser.add_argument("--dim_keys", type=int, default=None)
+
+parser.add_argument("--dim_hidden", type=int, default=None)
+
+parser.add_argument("--nb_heads", type=int, default=None)
+
+parser.add_argument("--nb_lines", type=int, default=None)
+
+parser.add_argument("--caterpillar_height", type=int, default=None)
+
+parser.add_argument("--rho", type=float, default=0.0)
+
+parser.add_argument("--dim_rec_v", type=int, default=None)
+
+parser.add_argument("--nb_blocks", type=int, default=None)
+
+parser.add_argument("--dropout", type=float, default=0.1)
+
+########################################
+
+parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
+
+parser.add_argument("--no_checkpoint", action="store_true", default=False)
+
+parser.add_argument("--overwrite_results", action="store_true", default=False)
+
+parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
+
+##############################
+# rpl options
+
+parser.add_argument("--rpl_nb_starting_values", type=int, default=3)
+
+parser.add_argument("--rpl_max_input", type=int, default=9)
+
+parser.add_argument("--rpl_prog_len", type=int, default=8)
+
+parser.add_argument("--rpl_nb_runs", type=int, default=5)
+
+parser.add_argument("--rpl_no_prog", action="store_true", default=False)
+
+##############################
+# grid options
+
+parser.add_argument("--grid_size", type=int, default=6)
+
+##############################
+# picoclvr options
+
+parser.add_argument("--picoclvr_nb_colors", type=int, default=5)
+
+parser.add_argument("--picoclvr_height", type=int, default=12)
+
+parser.add_argument("--picoclvr_width", type=int, default=16)
+
+parser.add_argument("--picocvlr_prune_properties", type=str, default="none")
+
+##############################
+# Maze options
+
+parser.add_argument("--maze_height", type=int, default=13)
+
+parser.add_argument("--maze_width", type=int, default=21)
+
+parser.add_argument("--maze_nb_walls", type=int, default=15)
+
+##############################
+# Snake options
+
+parser.add_argument("--snake_height", type=int, default=9)
+
+parser.add_argument("--snake_width", type=int, default=12)
+
+parser.add_argument("--snake_nb_colors", type=int, default=5)
+
+parser.add_argument("--snake_length", type=int, default=200)
+
+##############################
+# Stack options
+
+parser.add_argument("--stack_nb_steps", type=int, default=100)
+
+parser.add_argument("--stack_nb_stacks", type=int, default=3)
+
+parser.add_argument("--stack_nb_digits", type=int, default=3)
+
+parser.add_argument("--stack_fraction_values_for_train", type=float, default=0.75)
+
+##############################
+# Expr options
+
+parser.add_argument("--expr_nb_variables", type=int, default=5)
+
+parser.add_argument("--expr_sequence_length", type=int, default=40)
+
+parser.add_argument("--expr_operand_max", type=int, default=9)
+
+parser.add_argument("--expr_result_max", type=int, default=99)
+
+parser.add_argument("--expr_input_file", type=str, default=None)
+
+##############################
+# Memory
+
+parser.add_argument("--memory_len_total", type=int, default=32)
+
+##############################
+# Mixing
+
+parser.add_argument("--mixing_hard", action="store_true", default=False)
+
+parser.add_argument("--mixing_deterministic_start", action="store_true", default=False)
+
+######################################################################
+
+args = parser.parse_args()
+
+assert args.picoclvr_prune_properties in {"none", "train+eval", "eval"}
+
+if args.result_dir is None:
+    args.result_dir = f"results_{args.task}_{args.model}"
+
+######################################################################
+
+default_task_args = {
+    "addition": {
+        "model": "352M",
+        "batch_size": 25,
+        "nb_train_samples": 250000,
+        "nb_test_samples": 10000,
+    },
+    "byheart": {
+        "model": "37M",
+        "batch_size": 25,
+        "nb_train_samples": 50000,
+        "nb_test_samples": 10000,
+    },
+    "expr": {
+        "model": "352M",
+        "batch_size": 25,
+        "nb_train_samples": 2500000,
+        "nb_test_samples": 10000,
+    },
+    "grid": {
+        "model": "37M",
+        "batch_size": 25,
+        "nb_train_samples": 250000,
+        "nb_test_samples": 10000,
+    },
+    "qmlp": {
+        "model": "37M",
+        "batch_size": 10,
+        "nb_train_samples": 100000,
+        "nb_test_samples": 1000,
+    },
+    "guessop": {
+        "model": "352M",
+        "batch_size": 25,
+        "nb_train_samples": 1000000,
+        "nb_test_samples": 10000,
+    },
+    "learnop": {
+        "model": "37M",
+        "batch_size": 25,
+        "nb_train_samples": 50000,
+        "nb_test_samples": 10000,
+    },
+    "maze": {
+        "model": "37M",
+        "batch_size": 5,
+        "nb_train_samples": 100000,
+        "nb_test_samples": 10000,
+    },
+    "picoclvr": {
+        "model": "37M",
+        "batch_size": 25,
+        "nb_train_samples": 250000,
+        "nb_test_samples": 10000,
+    },
+    "rpl": {
+        "model": "352M",
+        "batch_size": 5,
+        "nb_train_samples": 2500000,
+        "nb_test_samples": 10000,
+    },
+    "snake": {
+        "model": "37M",
+        "batch_size": 25,
+        "nb_train_samples": 250000,
+        "nb_test_samples": 10000,
+    },
+    "stack": {
+        "model": "37M",
+        "batch_size": 25,
+        "nb_train_samples": 100000,
+        "nb_test_samples": 1000,
+    },
+    "twotargets": {
+        "model": "37M",
+        "batch_size": 25,
+        "nb_train_samples": 50000,
+        "nb_test_samples": 10000,
+    },
+    "memory": {
+        "model": "37M",
+        "batch_size": 25,
+        "nb_train_samples": 25000,
+        "nb_test_samples": 10000,
+    },
+    "mixing": {
+        "model": "37M",
+        "batch_size": 25,
+        "nb_train_samples": 250000,
+        "nb_test_samples": 10000,
+    },
+    "mnist": {
+        "model": "37M",
+        "batch_size": 10,
+        "nb_train_samples": 60000,
+        "nb_test_samples": 10000,
+    },
+}
+
+if args.task in default_task_args:
+    for k, v in default_task_args[args.task].items():
+        if getattr(args, k) is None:
+            setattr(args, k, v)
+
+######################################################################
+
+default_model_args = {
+    "17K": {
+        "attention": "mha",
+        "dim_model": 32,
+        "dim_keys": 32,
+        "dim_hidden": 32,
+        "nb_heads": 2,
+        "dim_rec_v": 16,
+        "nb_blocks": 2,
+    },
+    "17K-C": {
+        "attention": "caterpillar",
+        "dim_model": 32,
+        "dim_keys": 32,
+        "dim_hidden": 32,
+        "nb_heads": 2,
+        "nb_lines": 16,
+        "caterpillar_height": 4,
+        "dim_rec_v": 16,
+        "nb_blocks": 2,
+    },
+    "4M": {
+        "attention": "mha",
+        "dim_model": 256,
+        "dim_keys": 32,
+        "dim_hidden": 1024,
+        "nb_heads": 4,
+        "dim_rec_v": 64,
+        "nb_blocks": 6,
+    },
+    "4M-C": {
+        "attention": "caterpillar",
+        "dim_model": 256,
+        "dim_keys": 32,
+        "dim_hidden": 1024,
+        "nb_heads": 4,
+        "nb_lines": 32,
+        "caterpillar_height": 4,
+        "dim_rec_v": 64,  # dim_model / nb_heads
+        "nb_blocks": 6,
+    },
+    "37M": {
+        "dim_model": 512,
+        "dim_keys": 64,
+        "dim_hidden": 2048,
+        "nb_heads": 8,
+        "dim_rec_v": 64,
+        "nb_blocks": 12,
+    },
+    "37M-C": {
+        "attention": "caterpillar",
+        "dim_model": 512,
+        "dim_keys": 64,
+        "dim_hidden": 2048,
+        "nb_heads": 8,
+        "nb_lines": 256,
+        "caterpillar_height": 32,
+        "dim_rec_v": 64,
+        "nb_blocks": 12,
+    },
+    "122M": {
+        "attention": "mha",
+        "dim_model": 768,
+        "dim_keys": 64,
+        "dim_hidden": 2048,
+        "nb_heads": 8,
+        "dim_rec_v": 96,
+        "nb_blocks": 24,
+    },
+    "122M-C": {
+        "attention": "caterpillar",
+        "dim_model": 768,
+        "dim_keys": 64,
+        "dim_hidden": 2048,
+        "nb_heads": 8,
+        "nb_lines": 128,
+        "dim_rec_v": 96,
+        "nb_blocks": 24,
+    },
+    "352M": {
+        "attention": "mha",
+        "dim_model": 1024,
+        "dim_keys": 64,
+        "dim_hidden": 2048,
+        "nb_heads": 8,
+        "dim_rec_v": 128,
+        "nb_blocks": 48,
+    },
+    "352M-C": {
+        "attention": "caterpillar",
+        "dim_model": 1024,
+        "dim_keys": 64,
+        "dim_hidden": 2048,
+        "nb_heads": 8,
+        "nb_lines": 128,
+        "dim_rec_v": 128,
+        "nb_blocks": 48,
+    },
+}
+
+if args.model in default_model_args:
+    for k, v in default_model_args[args.model].items():
+        if getattr(args, k) is None:
+            setattr(args, k, v)
+else:
+    raise ValueError(f"Unknown model {args.model}")
+
+######################################################################
+
+try:
+    os.mkdir(args.result_dir)
+except FileExistsError:
+    if not args.overwrite_results:
+        print(f"result directory {args.result_dir} already exists")
+        exit(1)
+
+log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
+
+if args.seed >= 0:
+    # torch.backends.cudnn.deterministic = True
+    # torch.backends.cudnn.benchmark = False
+    # torch.use_deterministic_algorithms(True)
+    torch.manual_seed(args.seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(args.seed)
+
+######################################################################
+
+
+def log_string(s):
+    t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
+
+    if log_file is not None:
+        log_file.write(t + s + "\n")
+        log_file.flush()
+
+    print(t + s)
+    sys.stdout.flush()
+
+
+with os.popen("sha256sum *.py") as f:
+    for l in f:
+        log_string(f"sha256sum {l.strip()}")
+
+now = time.strftime("%Y%m%d-%H%M%S", time.localtime())
+os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py *.sh")
+
+log_string(f"argv {' '.join(sys.argv)}")
+
+for n in vars(args):
+    log_string(f"args.{n} {getattr(args, n)}")
+
+
+######################################################################
+
+# from nanoGPT
+
+
+def get_lr(it):
+    # 1) linear warmup for warmup_iter steps
+    if it < args.nb_warmup_iter:
+        return args.learning_rate * it / args.nb_warmup_iter
+    # 2) if it > nb_decay_iter, return min learning rate
+    if it > args.nb_decay_iter:
+        return args.min_learning_rate
+    # 3) in between, use cosine decay down to min learning rate
+    decay_ratio = (it - args.nb_warmup_iter) / (
+        args.nb_decay_iter - args.nb_warmup_iter
+    )
+    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
+    return args.min_learning_rate + coeff * (
+        args.learning_rate - args.min_learning_rate
+    )
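+
+# For example, with the defaults (nb_warmup_iter=100, nb_decay_iter=5000,
+# learning_rate=6e-4, min_learning_rate=6e-5): get_lr(50) = 3e-4 during
+# warmup, get_lr(2550) = 3.3e-4 at the cosine midpoint, and
+# get_lr(6000) = 6e-5 once decay is over.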
+
+
+######################################################################
+
+
+def picoclvr_pruner_horizontal_green(p):
+    return not ("green" in p and ("left" in p or "right" in p))
+
+
+picoclvr_pruner_train = (
+    picoclvr_pruner_horizontal_green
+    if args.picoclvr_prune_properties in {"train+eval"}
+    else None
+)
+
+picoclvr_pruner_eval = (
+    (lambda p: not picoclvr_pruner_horizontal_green(p))
+    if args.picoclvr_prune_properties in {"train+eval", "eval"}
+    else None
+)
+
+######################################################################
+
+device_data = device
+
+if args.task == "byheart":
+    task = tasks.SandBox(
+        problem=problems.ProblemByHeart(),
+        nb_train_samples=args.nb_train_samples,
+        nb_test_samples=args.nb_test_samples,
+        batch_size=args.batch_size,
+        logger=log_string,
+        device=device_data,
+    )
+    args.max_percents_of_test_in_train = -1
+
+elif args.task == "learnop":
+    task = tasks.SandBox(
+        problem=problems.ProblemLearnOperator(),
+        nb_train_samples=args.nb_train_samples,
+        nb_test_samples=args.nb_test_samples,
+        batch_size=args.batch_size,
+        logger=log_string,
+        device=device_data,
+    )
+
+
+elif args.task == "guessop":
+    task = tasks.SandBox(
+        problem=problems.ProblemGuessOperator(),
+        nb_train_samples=args.nb_train_samples,
+        nb_test_samples=args.nb_test_samples,
+        batch_size=args.batch_size,
+        logger=log_string,
+        device=device_data,
+    )
+
+
+elif args.task == "twotargets":
+    task = tasks.SandBox(
+        problem=problems.ProblemTwoTargets(),
+        nb_train_samples=args.nb_train_samples,
+        nb_test_samples=args.nb_test_samples,
+        batch_size=args.batch_size,
+        logger=log_string,
+        device=device_data,
+    )
+
+elif args.task == "memory":
+    task = tasks.SandBox(
+        problem=problems.ProblemMemory(len_total=args.memory_len_total),
+        nb_train_samples=args.nb_train_samples,
+        nb_test_samples=args.nb_test_samples,
+        batch_size=args.batch_size,
+        logger=log_string,
+        device=device_data,
+    )
+
+elif args.task == "mixing":
+    task = tasks.SandBox(
+        problem=problems.ProblemMixing(
+            hard=args.mixing_hard, random_start=not args.mixing_deterministic_start
+        ),
+        nb_train_samples=args.nb_train_samples,
+        nb_test_samples=args.nb_test_samples,
+        batch_size=args.batch_size,
+        logger=log_string,
+        device=device_data,
+    )
+
+elif args.task == "addition":
+    task = tasks.SandBox(
+        problem=problems.ProblemAddition(),
+        nb_train_samples=args.nb_train_samples,
+        nb_test_samples=args.nb_test_samples,
+        batch_size=args.batch_size,
+        logger=log_string,
+        device=device_data,
+    )
+
+elif args.task == "picoclvr":
+    task = tasks.PicoCLVR(
+        nb_train_samples=args.nb_train_samples,
+        nb_test_samples=args.nb_test_samples,
+        batch_size=args.batch_size,
+        height=args.picoclvr_height,
+        width=args.picoclvr_width,
+        nb_colors=args.picoclvr_nb_colors,
+        logger=log_string,
+        device=device_data,
+        pruner_train=picoclvr_pruner_train,
+        pruner_eval=picoclvr_pruner_eval,
+    )
+
+elif args.task == "mnist":
+    task = tasks.MNIST(
+        nb_train_samples=args.nb_train_samples,
+        nb_test_samples=args.nb_test_samples,
+        batch_size=args.batch_size,
+        device=device_data,
+    )
+
+elif args.task == "maze":
+    task = tasks.Maze(
+        nb_train_samples=args.nb_train_samples,
+        nb_test_samples=args.nb_test_samples,
+        batch_size=args.batch_size,
+        height=args.maze_height,
+        width=args.maze_width,
+        nb_walls=args.maze_nb_walls,
+        device=device_data,
+    )
+
+elif args.task == "snake":
+    task = tasks.Snake(
+        nb_train_samples=args.nb_train_samples,
+        nb_test_samples=args.nb_test_samples,
+        batch_size=args.batch_size,
+        height=args.snake_height,
+        width=args.snake_width,
+        nb_colors=args.snake_nb_colors,
+        length=args.snake_length,
+        prompt_length=args.snake_length // 2,
+        device=device_data,
+    )
+
+elif args.task == "stack":
+    task = tasks.Stack(
+        nb_train_samples=args.nb_train_samples,
+        nb_test_samples=args.nb_test_samples,
+        batch_size=args.batch_size,
+        logger=log_string,
+        nb_steps=args.stack_nb_steps,
+        nb_stacks=args.stack_nb_stacks,
+        nb_digits=args.stack_nb_digits,
+        fraction_values_for_train=args.stack_fraction_values_for_train,
+        device=device_data,
+    )
+
+elif args.task == "expr":
+    task = tasks.Expr(
+        nb_train_samples=args.nb_train_samples,
+        nb_test_samples=args.nb_test_samples,
+        nb_variables=args.expr_nb_variables,
+        sequence_length=args.expr_sequence_length,
+        operand_max=args.expr_operand_max,
+        result_max=args.expr_result_max,
+        batch_size=args.batch_size,
+        device=device_data,
+    )
+
+elif args.task == "rpl":
+    task = tasks.RPL(
+        nb_train_samples=args.nb_train_samples,
+        nb_test_samples=args.nb_test_samples,
+        batch_size=args.batch_size,
+        nb_starting_values=args.rpl_nb_starting_values,
+        max_input=args.rpl_max_input,
+        prog_len=args.rpl_prog_len,
+        nb_runs=args.rpl_nb_runs,
+        no_prog=args.rpl_no_prog,
+        logger=log_string,
+        device=device_data,
+    )
+
+elif args.task == "grid":
+    task = tasks.Grid(
+        nb_train_samples=args.nb_train_samples,
+        nb_test_samples=args.nb_test_samples,
+        batch_size=args.batch_size,
+        size=args.grid_size,
+        logger=log_string,
+        device=device_data,
+    )
+
+elif args.task == "qmlp":
+    task = tasks.QMLP(
+        nb_train_samples=args.nb_train_samples,
+        nb_test_samples=args.nb_test_samples,
+        batch_size=args.batch_size,
+        result_dir=args.result_dir,
+        logger=log_string,
+        device=device_data,
+    )
+
+else:
+    raise ValueError(f"Unknown task {args.task}")
+
+######################################################################
+
+log_string(f"device {device}")
+
+vocabulary_size = task.vocabulary_size()
+
+log_string(f"vocabulary_size {vocabulary_size}")
+
+##############################
+
+model = mygpt.MyGPT(
+    vocabulary_size=vocabulary_size,
+    dim_model=args.dim_model,
+    dim_keys=args.dim_keys,
+    dim_hidden=args.dim_hidden,
+    nb_heads=args.nb_heads,
+    nb_lines=args.nb_lines,
+    caterpillar_height=args.caterpillar_height,
+    dim_rec_v=args.dim_rec_v,
+    nb_blocks=args.nb_blocks,
+    causal=True,
+    dropout=args.dropout,
+    attention_layer=args.attention,
+)
+
+model.to(device)
+
+nb_parameters = sum(p.numel() for p in model.parameters())
+log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
+
+######################################################################
+
+nb_epochs_finished = 0
+
+if args.no_checkpoint:
+    log_string(f"not trying to load checkpoint.")
+
+else:
+    try:
+        checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
+        checkpoint = torch.load(checkpoint_name)
+        nb_epochs_finished = checkpoint["nb_epochs_finished"]
+        model.load_state_dict(checkpoint["model_state"])
+        torch.set_rng_state(checkpoint["rng_state"])
+        if torch.cuda.is_available():
+            torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])
+
+        log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")
+
+    except FileNotFoundError:
+        log_string("starting from scratch.")
+
+    except Exception:
+        log_string("error when loading the checkpoint.")
+        exit(1)
+
+######################################################################
+
+if args.task == "expr" and args.expr_input_file is not None:
+    task.produce_results(
+        n_epoch=nb_epochs_finished,
+        model=model,
+        result_dir=args.result_dir,
+        logger=log_string,
+        deterministic_synthesis=args.deterministic_synthesis,
+        input_file=args.expr_input_file,
+    )
+
+    exit(0)
+
+######################################################################
+
+assert args.nb_epochs > 0
+nb_epochs = args.nb_epochs
+
+# Compute the entropy of the training tokens
+
+token_count = 0
+for input in task.batches(split="train"):
+    token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
+token_probas = token_count / token_count.sum()
+entropy = -torch.xlogy(token_probas, token_probas).sum()
+train_set_perplexity = math.exp(entropy)
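+# (sanity check: a uniform distribution over V tokens has entropy
+# log V, so train_set_perplexity would equal the vocabulary size)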
+
+######################################################################
+# A bit of paranoia never hurts
+
+if args.max_percents_of_test_in_train >= 0:
+
+    def subsets_as_tuples(batches, cs):
+        s = set()
+        for batch in batches:
+            for x in batch:
+                s.add(tuple([v.item() for v in x]))
+                if len(s) == cs:
+                    yield s
+                    s = set()
+        yield s
+
+    nb_test, nb_in_train = 0, 0
+    for test_subset in subsets_as_tuples(task.batches(split="test"), 25000):
+        in_train = set()
+        for train_subset in subsets_as_tuples(task.batches(split="train"), 25000):
+            in_train.update(test_subset.intersection(train_subset))
+        nb_in_train += len(in_train)
+        nb_test += len(test_subset)
+
+    log_string(
+        f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
+    )
+
+    assert (
+        nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
+    ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"
+
+##############################
+
+nb_samples_seen = 0
+
+if nb_epochs_finished >= nb_epochs:
+    task.produce_results(
+        n_epoch=nb_epochs_finished,
+        model=model,
+        result_dir=args.result_dir,
+        logger=log_string,
+        deterministic_synthesis=args.deterministic_synthesis,
+    )
+
+time_pred_result = None
+
+it = 0
+
+for n_epoch in range(nb_epochs_finished, nb_epochs):
+    if args.optim == "sgd":
+        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
+    elif args.optim == "adam":
+        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+    elif args.optim == "adamw":
+        optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
+    else:
+        raise ValueError(f"Unknown optimizer {args.optim}.")
+
+    model.train()
+
+    nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0
+
+    for input in task.batches(split="train"):
+        model.reset_inner_loss()
+        input = input.to(device)
+
+        output = model(mygpt.BracketedSequence(input)).x
+        loss = F.cross_entropy(output.transpose(1, 2), input)
+        inner_loss = model.get_inner_loss()
+
+        acc_train_loss += loss.item() * input.size(0)
+        acc_train_inner_loss += inner_loss.item() * input.size(0)
+
+        nb_train_samples += input.size(0)
+        nb_samples_seen += input.size(0)
+
+        total_loss = loss + (args.rho * inner_loss if args.rho > 0 else 0.0)
+
+        it += 1
+        lr = get_lr(it)
+        for param_group in optimizer.param_groups:
+            param_group["lr"] = lr
+
+        # log_string(f"learning_rate {lr}")
+
+        optimizer.zero_grad()
+        total_loss.backward()
+        optimizer.step()
+
+    with torch.autograd.no_grad():
+        model.eval()
+
+        nb_test_samples, acc_test_loss = 0, 0.0
+
+        for input in task.batches(split="test"):
+            input = input.to(device)
+
+            output = model(mygpt.BracketedSequence(input)).x
+            loss = F.cross_entropy(output.transpose(1, 2), input)
+            acc_test_loss += loss.item() * input.size(0)
+            nb_test_samples += input.size(0)
+
+        log_string(
+            f"loss {n_epoch} train_loss {acc_train_loss/nb_train_samples} train_inner_loss {acc_train_inner_loss/nb_train_samples} test_prediction {acc_test_loss/nb_test_samples}"
+        )
+
+        task.produce_results(
+            n_epoch=n_epoch,
+            model=model,
+            result_dir=args.result_dir,
+            logger=log_string,
+            deterministic_synthesis=args.deterministic_synthesis,
+        )
+
+        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
+
+        log_string(
+            f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
+        )
+
+        time_current_result = datetime.datetime.now()
+        if time_pred_result is not None:
+            log_string(
+                f"next_result {time_current_result + (time_current_result - time_pred_result)}"
+            )
+        time_pred_result = time_current_result
+
+    checkpoint = {
+        "nb_epochs_finished": n_epoch + 1,
+        "model_state": model.state_dict(),
+        "rng_state": torch.get_rng_state(),
+    }
+
+    if torch.cuda.is_available():
+        checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()
+
+    checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
+    torch.save(checkpoint, checkpoint_name)
+    log_string(f"saved checkpoint {checkpoint_name}")
+
+######################################################################
diff --git a/maze.py b/maze.py
new file mode 100755 (executable)
index 0000000..8ac9fce
--- /dev/null
+++ b/maze.py
@@ -0,0 +1,309 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import torch, torchvision
+
+######################################################################
+
+v_empty, v_wall, v_start, v_goal, v_path = 0, 1, 2, 3, 4
+
+
+def create_maze(h=11, w=17, nb_walls=8):
+    assert h % 2 == 1 and w % 2 == 1
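+    # builds the border walls, then adds nb_walls random wall segments
+    # on even rows/columns; a segment is kept only if it crosses at
+    # most one existing wall cell, and the whole maze is restarted if
+    # too many attempts fail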
+
+    a, k = 0, 0
+
+    while k < nb_walls:
+        while True:
+            if a == 0:
+                m = torch.zeros(h, w, dtype=torch.int64)
+                m[0, :] = 1
+                m[-1, :] = 1
+                m[:, 0] = 1
+                m[:, -1] = 1
+
+            r = torch.rand(4)
+
+            if r[0] <= 0.5:
+                i1, i2, j = (
+                    int((r[1] * h).item()),
+                    int((r[2] * h).item()),
+                    int((r[3] * w).item()),
+                )
+                i1, i2, j = i1 - i1 % 2, i2 - i2 % 2, j - j % 2
+                i1, i2 = min(i1, i2), max(i1, i2)
+                if i2 - i1 > 1 and i2 - i1 <= h / 2 and m[i1 : i2 + 1, j].sum() <= 1:
+                    m[i1 : i2 + 1, j] = 1
+                    break
+            else:
+                i, j1, j2 = (
+                    int((r[1] * h).item()),
+                    int((r[2] * w).item()),
+                    int((r[3] * w).item()),
+                )
+                i, j1, j2 = i - i % 2, j1 - j1 % 2, j2 - j2 % 2
+                j1, j2 = min(j1, j2), max(j1, j2)
+                if j2 - j1 > 1 and j2 - j1 <= w / 2 and m[i, j1 : j2 + 1].sum() <= 1:
+                    m[i, j1 : j2 + 1] = 1
+                    break
+            a += 1
+
+            if a > 10 * nb_walls:
+                a, k = 0, 0
+
+        k += 1
+
+    return m
+
+
+######################################################################
+
+
+def compute_distance(walls, goal_i, goal_j):
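+    # distance transform by fixed-point iteration: every free cell is
+    # repeatedly relaxed to min(neighbor distances) + 1; wall cells
+    # keep the max_length sentinel during the iteration and are zeroed
+    # in the returned map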
+    max_length = walls.numel()
+    dist = torch.full_like(walls, max_length)
+
+    dist[goal_i, goal_j] = 0
+    pred_dist = torch.empty_like(dist)
+
+    while True:
+        pred_dist.copy_(dist)
+        d = (
+            torch.cat(
+                (
+                    dist[None, 1:-1, 0:-2],
+                    dist[None, 2:, 1:-1],
+                    dist[None, 1:-1, 2:],
+                    dist[None, 0:-2, 1:-1],
+                ),
+                0,
+            ).min(dim=0)[0]
+            + 1
+        )
+
+        dist[1:-1, 1:-1] = torch.min(dist[1:-1, 1:-1], d)
+        dist = walls * max_length + (1 - walls) * dist
+
+        if dist.equal(pred_dist):
+            return dist * (1 - walls)
+
+
+######################################################################
+
+
+def compute_policy(walls, goal_i, goal_j):
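+    # for every cell, put uniform probability mass on the moves
+    # (<, >, ^, v) leading to the neighbor(s) closest to the goal;
+    # wall cells get a uniform 1/4 distribution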
+    distance = compute_distance(walls, goal_i, goal_j)
+    distance = distance + walls.numel() * walls
+
+    value = distance.new_full((4,) + distance.size(), walls.numel())
+    value[0, :, 1:] = distance[:, :-1]  # <
+    value[1, :, :-1] = distance[:, 1:]  # >
+    value[2, 1:, :] = distance[:-1, :]  # ^
+    value[3, :-1, :] = distance[1:, :]  # v
+
+    proba = (value.min(dim=0)[0][None] == value).float()
+    proba = proba / proba.sum(dim=0)[None]
+    proba = proba * (1 - walls) + walls.float() / 4
+
+    return proba
+
+
+def stationary_densities(mazes, policies):
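+    # propagates the unit mass placed on the start cell through the
+    # policies until a fixed point is reached; goal cells do not
+    # re-emit mass, and the start cells are clamped back to 1 at
+    # every iteration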
+    policies = policies * (mazes != v_goal)[:, None]
+    start = (mazes == v_start).nonzero(as_tuple=True)
+    probas = mazes.new_zeros(mazes.size(), dtype=torch.float32)
+    pred_probas = probas.clone()
+    probas[start] = 1.0
+
+    while not pred_probas.equal(probas):
+        pred_probas.copy_(probas)
+        probas.zero_()
+        probas[:, 1:, :] += pred_probas[:, :-1, :] * policies[:, 3, :-1, :]
+        probas[:, :-1, :] += pred_probas[:, 1:, :] * policies[:, 2, 1:, :]
+        probas[:, :, 1:] += pred_probas[:, :, :-1] * policies[:, 1, :, :-1]
+        probas[:, :, :-1] += pred_probas[:, :, 1:] * policies[:, 0, :, 1:]
+        probas[start] = 1.0
+
+    return probas
+
+
+######################################################################
+
+
+def mark_path(walls, i, j, goal_i, goal_j, policy):
+    action = torch.distributions.categorical.Categorical(
+        policy.permute(1, 2, 0)
+    ).sample()
+    n, nmax = 0, walls.numel()
+    while i != goal_i or j != goal_j:
+        di, dj = [(0, -1), (0, 1), (-1, 0), (1, 0)][action[i, j]]
+        i, j = i + di, j + dj
+        assert walls[i, j] == 0
+        walls[i, j] = v_path
+        n += 1
+        assert n < nmax
+
+
+def path_optimality(ref_paths, paths):
+    return (ref_paths == v_path).long().flatten(1).sum(1) == (
+        paths == v_path
+    ).long().flatten(1).sum(1)
+
+
+def path_correctness(mazes, paths):
+    still_ok = (mazes - (paths * (paths != v_path))).view(mazes.size(0), -1).abs().sum(
+        1
+    ) == 0
+    reached = still_ok.new_zeros(still_ok.size())
+    current, pred_current = paths.clone(), paths.new_zeros(paths.size())
+    goal = (mazes == v_goal).long()
+    while not pred_current.equal(current):
+        pred_current.copy_(current)
+        u = (current == v_start).long()
+        possible_next = (
+            u[:, 2:, 1:-1] + u[:, 0:-2, 1:-1] + u[:, 1:-1, 2:] + u[:, 1:-1, 0:-2] > 0
+        ).long()
+        u = u[:, 1:-1, 1:-1]
+        reached += ((goal[:, 1:-1, 1:-1] * possible_next).sum((1, 2)) == 1) * (
+            (current == v_path).sum((1, 2)) == 0
+        )
+        current[:, 1:-1, 1:-1] = (1 - u) * current[:, 1:-1, 1:-1] + (
+            v_start - v_path
+        ) * (possible_next * (current[:, 1:-1, 1:-1] == v_path))
+        still_ok *= (current == v_start).sum((1, 2)) <= 1
+
+    return still_ok * reached
+
+
+######################################################################
+
+
+def create_maze_data(
+    nb, height=11, width=17, nb_walls=8, dist_min=10, progress_bar=lambda x: x
+):
+    mazes = torch.empty(nb, height, width, dtype=torch.int64)
+    paths = torch.empty(nb, height, width, dtype=torch.int64)
+    policies = torch.empty(nb, 4, height, width)
+
+    for n in progress_bar(range(nb)):
+        maze = create_maze(height, width, nb_walls)
+        i = (maze == v_empty).nonzero()
+        while True:
+            start, goal = i[torch.randperm(i.size(0))[:2]]
+            if (start - goal).abs().sum() >= dist_min:
+                break
+        start_i, start_j, goal_i, goal_j = start[0], start[1], goal[0], goal[1]
+
+        policy = compute_policy(maze, goal_i, goal_j)
+        path = maze.clone()
+        mark_path(path, start_i, start_j, goal_i, goal_j, policy)
+        maze[start_i, start_j] = v_start
+        maze[goal_i, goal_j] = v_goal
+        path[start_i, start_j] = v_start
+        path[goal_i, goal_j] = v_goal
+
+        mazes[n] = maze
+        paths[n] = path
+        policies[n] = policy
+
+    return mazes, paths, policies
+
+
+######################################################################
+
+
+def save_image(
+    name,
+    mazes,
+    target_paths=None,
+    predicted_paths=None,
+    path_correct=None,
+    path_optimal=None,
+):
+    colors = torch.tensor(
+        [
+            [255, 255, 255],  # empty
+            [0, 0, 0],  # wall
+            [0, 255, 0],  # start
+            [127, 127, 255],  # goal
+            [255, 0, 0],  # path
+        ]
+    )
+
+    mazes = mazes.cpu()
+
+    c_mazes = (
+        colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2)
+    )
+
+    imgs = c_mazes.unsqueeze(1)
+
+    if target_paths is not None:
+        target_paths = target_paths.cpu()
+
+        c_target_paths = (
+            colors[target_paths.reshape(-1)]
+            .reshape(target_paths.size() + (-1,))
+            .permute(0, 3, 1, 2)
+        )
+
+        imgs = torch.cat((imgs, c_target_paths.unsqueeze(1)), 1)
+
+    if predicted_paths is not None:
+        predicted_paths = predicted_paths.cpu()
+        c_predicted_paths = (
+            colors[predicted_paths.reshape(-1)]
+            .reshape(predicted_paths.size() + (-1,))
+            .permute(0, 3, 1, 2)
+        )
+        imgs = torch.cat((imgs, c_predicted_paths.unsqueeze(1)), 1)
+
+    img = torch.tensor([255, 255, 0]).view(1, -1, 1, 1)
+
+    # NxKxCxHxW
+    if path_optimal is not None:
+        path_optimal = path_optimal.cpu().long().view(-1, 1, 1, 1)
+        img = (
+            img * (1 - path_optimal)
+            + torch.tensor([0, 255, 0]).view(1, -1, 1, 1) * path_optimal
+        )
+
+    if path_correct is not None:
+        path_correct = path_correct.cpu().long().view(-1, 1, 1, 1)
+        img = img * path_correct + torch.tensor([255, 0, 0]).view(1, -1, 1, 1) * (
+            1 - path_correct
+        )
+
+    img = img.expand(
+        -1, -1, imgs.size(3) + 2, 1 + imgs.size(1) * (1 + imgs.size(4))
+    ).clone()
+
+    print(f"{img.size()=} {imgs.size()=}")
+
+    for k in range(imgs.size(1)):
+        img[
+            :,
+            :,
+            1 : 1 + imgs.size(3),
+            1 + k * (1 + imgs.size(4)) : 1 + k * (1 + imgs.size(4)) + imgs.size(4),
+        ] = imgs[:, k]
+
+    img = img.float() / 255.0
+
+    torchvision.utils.save_image(img, name, nrow=4, padding=1, pad_value=224.0 / 256)
+
+
+######################################################################
+
+if __name__ == "__main__":
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    mazes, paths, policies = create_maze_data(8)
+    mazes, paths = mazes.to(device), paths.to(device)
+    save_image("test.png", mazes=mazes, target_paths=paths, predicted_paths=paths)
+    print(path_correctness(mazes, paths))
+
+######################################################################
diff --git a/memload.py b/memload.py
new file mode 100755 (executable)
index 0000000..5fcd089
--- /dev/null
+++ b/memload.py
@@ -0,0 +1,84 @@
+#!/usr/bin/env python
+
+import torch
+
+from torch.utils.cpp_extension import load_inline
+
+cpp_source = """
+std::vector<torch::Tensor> greedy_lines_allocation(torch::Tensor load_start, float decay, torch::Tensor line_requests) {
+  auto nb_lines = load_start.size(1);
+  auto batch_size = line_requests.size(0);
+  auto nb_heads = line_requests.size(1);
+  auto T = line_requests.size(2);
+
+  auto load_start_a = load_start.accessor<float,2>();
+  auto line_requests_a = line_requests.accessor<float,3>();
+
+  auto load = torch::empty({batch_size, nb_lines, T});
+  auto load_a = load.accessor<float,3>();
+
+  auto allocation_result = torch::empty({batch_size,nb_heads,T},torch::TensorOptions().dtype(torch::kInt64));
+  auto allocation_result_a = allocation_result.accessor<long,3>();
+
+  for(int n = 0; n < batch_size; n++) {
+    for(int t = 0; t < T; t++) {
+      for(int l = 0; l < nb_lines; l++) {
+        // Use the accessor for element-wise writes, consistent with the
+        // reads below and cheaper than per-element Tensor indexing
+        if(t == 0) {
+          load_a[n][l][t] = decay * load_start_a[n][l];
+        } else {
+          load_a[n][l][t] = decay * load_a[n][l][t-1];
+        }
+      }
+      for(int h = 0; h < nb_heads; h++) {
+        if(line_requests_a[n][h][t] > 0) {
+          int l_lowest_load;
+          for(int l = 0; l < nb_lines; l++) {
+            if(l == 0 || load_a[n][l][t]<load_a[n][l_lowest_load][t]) l_lowest_load=l;
+          }
+          if(load_a[n][l_lowest_load][t] < line_requests_a[n][h][t]) {
+            allocation_result_a[n][h][t] = l_lowest_load;
+            load_a[n][l_lowest_load][t] = line_requests_a[n][h][t];
+          } else {
+            allocation_result_a[n][h][t] = -1;
+          }
+        } else {
+          allocation_result_a[n][h][t] = -1;
+        }
+      }
+    }
+  }
+
+  return {allocation_result,load};
+}
+"""
+
+######################################################################
+
+allocator_module = load_inline(
+    name="allocator_module",
+    cpp_sources=[cpp_source],
+    functions=["greedy_lines_allocation"],
+    build_directory="/tmp/",
+    # verbose=True,
+)
+
+lines_allocation = allocator_module.greedy_lines_allocation
+
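+# A pure-Python reference of the same greedy allocation (a sketch for
+# checking the C++ kernel above; not part of the compiled module): at
+# each time step the loads decay, and each requesting head takes the
+# least-loaded line if its request exceeds that line's current load.
+
+
+def greedy_lines_allocation_py(load_start, decay, line_requests):
+    N, L = load_start.size()
+    _, H, T = line_requests.size()
+    load = torch.empty(N, L, T)
+    alloc = torch.full((N, H, T), -1, dtype=torch.int64)
+    for n in range(N):
+        for t in range(T):
+            # Decay the loads from the previous time step
+            load[n, :, t] = decay * (load_start[n] if t == 0 else load[n, :, t - 1])
+            for h in range(H):
+                r = line_requests[n, h, t].item()
+                if r > 0:
+                    l = load[n, :, t].argmin().item()
+                    if load[n, l, t] < r:
+                        alloc[n, h, t] = l
+                        load[n, l, t] = r
+    return alloc, load
+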
+######################################################################
+
+if __name__ == "__main__":
+    N, H, L, T = 1, 1, 3, 20
+
+    load_start = torch.rand(N, L)
+    requests = (2 * torch.rand(N, H, T) - 1).clamp(min=0)
+
+    print("load_start", load_start)
+
+    print("requests", requests)
+
+    alloc, load = lines_allocation(load_start, 0.99, requests)
+
+    print("alloc", alloc)
+
+    print("load", load)
diff --git a/mygpt.py b/mygpt.py
new file mode 100755 (executable)
index 0000000..90102bf
--- /dev/null
+++ b/mygpt.py
@@ -0,0 +1,954 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+# This is an implementation from scratch of a "GPT", that is, a model
+# composed of several causal self-attention blocks. It is equipped
+# with a caching mechanism for keys and values to avoid an O(N^3) cost
+# for auto-regression.
+
+import math, warnings
+
+import torch, einops
+
+from torch import nn
+from torch.nn import functional as F
+from functorch.dim import dims
+
+import ffutils
+
+# import memload
+
+######################################################################
+
+# A BracketedSequence is a BxTx... tensor together with the index of a
+# first time step and a number of time steps to compute.
+
+# Modules able to process it expect that they will have to process a
+# first bracket starting at t=0, followed by a succession of brackets
+# that move forward in time, do not overlap, and cover the axis T with
+# no holes.
+#
+# Although it is more general, for a classical prompt-conditioned
+# auto-regressive process it will be a first bracket starting at 0 and
+# of arbitrary length for the "prompt", followed by brackets of length
+# 1 for the successive tokens.
+#
+# Modules able to process brackets may implement a cache that is
+# reset when the input bracket starts at t=0
+
+
+class BracketedSequence:
+    def __init__(self, x, first=None, nb=None, init_cache=None):
+        self.x = x
+        assert (first is None and nb is None and init_cache is None) or (
+            first is not None and nb is not None and init_cache is not None
+        )
+
+        self.first = 0 if first is None else first
+        self.nb = x.size(1) if nb is None else nb
+        self.init_cache = True if init_cache is None else init_cache
+
+    def slice(self):
+        return self.x[:, self.first : self.first + self.nb]
+
+    def complete(self):
+        return self.first == 0 and self.nb == self.x.size(1)
+
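+# A usage sketch (not from the original code): a prompt-conditioned
+# auto-regressive pass processes one bracket covering the prompt, then
+# one bracket of length 1 per generated token:
+#
+#   model(BracketedSequence(x, first=0, nb=prompt_len, init_cache=True))
+#   for t in range(prompt_len, x.size(1)):
+#       model(BracketedSequence(x, first=t, nb=1, init_cache=False))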
+
+######################################################################
+
+
+class CacheWrapper(nn.Module):
+    def __init__(self, *f):
+        super().__init__()
+        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+    def forward(self, bs):
+        if bs.init_cache:
+            y = self.f(bs.slice())
+            self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
+            self.cache_y[:, bs.first : bs.first + bs.nb] = y
+        else:
+            assert tuple(bs.x.size()[:2]) == tuple(self.cache_y.size()[:2])
+            assert bs.first + bs.nb <= self.cache_y.size(1)
+            self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
+
+        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
+
+
+##############################
+
+
+class WithResidual(nn.Module):
+    def __init__(self, *f):
+        super().__init__()
+        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+    def forward(self, bs):
+        return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb, bs.init_cache)
+
+
+##############################
+
+
+class AddPositionalEncoding(nn.Module):
+    def __init__(self, len_max):
+        super().__init__()
+        self.len_max = len_max
+
+    # [Vaswani et al. 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
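+    # Below, k = j % 2 and the phase shift math.pi/2 * k turns the sine
+    # into a cosine on odd dimensions, so one expression covers both cases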
+
+    def forward(self, bs):
+        if bs.init_cache:
+            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
+                :, None
+            ]
+            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
+                None, :
+            ]
+            k = j % 2
+            self.pe = torch.sin(
+                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
+            )
+            self.cache_y = bs.x.new(bs.x.size())
+
+        self.cache_y[:, bs.first : bs.first + bs.nb] = (
+            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
+        )
+
+        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
+
+
+import pscan
+
+
+# X is of shape (...)xTxD, A of shape (...)xT, Y_init of shape (...)xD
+
+
+def pscan_dim(A, X, Y_init, dim=-2):
+    s = X.size()
+    a, T, b = s[:dim].numel(), s[dim], s[dim + 1 :].numel()
+
+    A = A.reshape(a, T, *s[dim + 1 : -1])
+    X = X.reshape(a, T, *s[dim + 1 : -1], -1)
+
+    if Y_init is None:
+        Y_init = X.new_zeros(a, *s[dim + 1 : -1], X.size(-1))
+    else:
+        Y_init = Y_init.reshape(a, *s[dim + 1 : -1], -1)
+
+    Y = pscan.pscan(A, X, Y_init).reshape(s)
+
+    return Y
+
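+# For example (shapes only, assuming pscan.pscan computes the linear
+# recurrence Y[t] = A[t] * Y[t-1] + X[t] along the time axis): with A of
+# size (2, 3, 7) and X of size (2, 3, 7, 5), pscan_dim(A, X, None, dim=2)
+# returns Y of size (2, 3, 7, 5).
+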
+
+def pscan_shape(A, X, Y_init):
+    s = X.size()
+    A = A.reshape(-1, s[-2])
+    X = X.reshape(-1, s[-2], s[-1])
+
+    if Y_init is None:
+        Y_init = X.new_zeros(X.size(0), s[-1])
+    else:
+        Y_init = Y_init.reshape(-1, s[-1])
+
+    Y = pscan.pscan(A, X, Y_init).reshape(s)
+
+    return Y
+
+
+def nsum_shape(X, Y_init):
+    s = X.size()
+    X = X.reshape(-1, s[-2], s[-1])  # ntd
+
+    Y = 0 if Y_init is None else Y_init.reshape(-1, s[-1])
+    result = []
+
+    for k in range(X.size(1)):
+        Y = Y + X[:, k]
+        Y = Y / Y.norm(dim=-1, keepdim=True).clamp(min=1)
+        result.append(Y)
+
+    return torch.cat(result, dim=1).reshape(s)
+
+
+##############################
+
+
+class DumbRec(nn.Module):
+    def __init__(
+        self,
+        dim_in,
+        dim_qk,
+        dim_v,
+        nb_heads,
+        nb_lines,
+        attention_dropout=0.0,
+        len_max=1e5,
+    ):
+        super().__init__()
+
+        def randw(*d):
+            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+
+        self.nb_lines = nb_lines
+        self.attention_dropout = attention_dropout
+
+        self.k_star = randw(nb_lines, dim_qk)
+
+        self.w_qw = randw(nb_heads, dim_qk, dim_in)
+        self.w_qr = randw(nb_heads, dim_qk, dim_in)
+        # self.w_k = randw(nb_heads, dim_qk, dim_in)
+        self.w_v = randw(nb_heads, dim_v, dim_in)
+        self.w_o = randw(dim_v * nb_heads, dim_in)
+
+    def reset_inner_loss(self):
+        self.acc_attention = 0
+        self.acc_nb = 0
+
+    def get_inner_loss(self):
+        warnings.warn("l2 regularization", RuntimeWarning)
+        return (self.acc_attention / self.acc_nb).pow(2).sum()
+        # return torch.tensor([0], device=self.w_qw.device)
+
+    def forward(self, bs):
+        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
+
+        if bs.init_cache:
+            self.rec_v = x_q.new_zeros(
+                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
+            )
+            # self.rec_k = x_q.new_zeros(
+            # x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
+            # )
+            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
+
+        ######################################################################
+        # Prepare the keys
+
+        k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
+
+        warnings.warn("rotating key barrel", RuntimeWarning)
+        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
+        t_barrel = torch.arange(t0, t1, device=k_star.device)
+        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
+        l_barrel = (
+            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
+        ) % k_star.size(0)
+        k_star = k_star[l_barrel, t_barrel]
+
+        ######################################################################
+        # Compute the recurrent state
+
+        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
+
+        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
+        # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
+
+        aw = torch.einsum(
+            "nhtd,ltd->nhlt",
+            qw,
+            k_star,
+        ) / math.sqrt(self.w_qw.size(1))
+
+        aw = aw.softmax(dim=2)  # nhlt
+
+        if self.training:
+            self.acc_attention += aw.sum(dim=(0, 1, 3))
+            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
+
+        aw = F.dropout(aw, self.attention_dropout, self.training)
+
+        A = 1 - aw.sum(dim=1)  # nlt
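+        # Assuming pscan computes Y[t] = A[t] * Y[t-1] + X[t], the update
+        # below makes each memory line evolve as
+        # rec_v[l, t] = A[l, t] * rec_v[l, t-1] + V[l, t], where V
+        # aggregates the per-head writes weighted by the attention aw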
+
+        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
+        # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
+
+        if t0 == 0:
+            V0 = None
+            # K0 = None
+        else:
+            V0 = self.rec_v[:, :, t0 - 1]
+            # K0 = self.rec_k[:, :, t0 - 1]
+
+        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
+        # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
+
+        ######################################################################
+        # compute the readout
+
+        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
+
+        ar = torch.einsum(
+            "nhtd,ld->nhlt",
+            qr,
+            # self.rec_k[:, :, t0:t1],
+            self.k_star,
+        ) / math.sqrt(self.w_qr.size(1))
+
+        ar = ar.softmax(dim=2)  # nhlt
+
+        ar = F.dropout(ar, self.attention_dropout, self.training)
+
+        y = torch.einsum(
+            "nhlt,nltd->nthd",
+            ar,
+            self.rec_v[:, :, t0:t1],
+        ).flatten(2)
+
+        self.cache_y[:, t0:t1] = y @ self.w_o
+
+        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
+
+
+##############################
+
+
+class KVRec(nn.Module):
+    def __init__(
+        self,
+        dim_in,
+        dim_qk,
+        dim_v,
+        nb_heads,
+        nb_lines,
+        attention_dropout=0.0,
+        len_max=1e5,
+    ):
+        super().__init__()
+
+        def randw(*d):
+            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+
+        self.nb_lines = nb_lines
+        self.attention_dropout = attention_dropout
+
+        self.k_star = randw(nb_lines, dim_qk)
+
+        self.w_qw = randw(nb_heads, dim_qk, dim_in)
+        self.w_qr = randw(nb_heads, dim_qk, dim_in)
+        self.w_k = randw(nb_heads, dim_qk, dim_in)
+        self.w_v = randw(nb_heads, dim_v, dim_in)
+        self.w_o = randw(dim_v * nb_heads, dim_in)
+
+    def reset_inner_loss(self):
+        self.acc_attention = 0
+        self.acc_nb = 0
+
+    def get_inner_loss(self):
+        warnings.warn("l2 regularization", RuntimeWarning)
+        return (self.acc_attention / self.acc_nb).pow(2).sum()
+        # return torch.tensor([0], device=self.w_qw.device)
+        # warnings.warn("side regularization", RuntimeWarning)
+        # return (
+        # (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
+        # )
+        # return torch.tensor([0], device=self.w_qw.device)
+
+    def forward(self, bs):
+        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
+
+        # n,h,l,t,d = dims(5)
+
+        if bs.init_cache:
+            self.rec_v = x_q.new_zeros(
+                x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
+            )
+            self.rec_k = x_q.new_zeros(
+                x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
+            )
+            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
+
+        ######################################################################
+        # Prepare the keys
+
+        k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
+
+        warnings.warn("rotating key barrel", RuntimeWarning)
+        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
+        t_barrel = torch.arange(t0, t1, device=k_star.device)
+        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
+        l_barrel = (
+            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
+        ) % k_star.size(0)
+        k_star = k_star[l_barrel, t_barrel]
+
+        ######################################################################
+        # Compute the recurrent state
+
+        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
+
+        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
+        k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
+
+        aw = torch.einsum(
+            "nhtd,ltd->nhlt",
+            qw,
+            k_star,
+        ) / math.sqrt(self.w_qw.size(1))
+
+        aw = aw.softmax(dim=2)  # nhlt
+
+        if self.training:
+            # We want all the memory lines to be used similarly
+            self.acc_attention += aw.sum(dim=(0, 1, 3))  # Sum across NxHx_xT
+            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
+
+        aw = F.dropout(aw, self.attention_dropout, self.training)
+
+        A = 1 - aw.sum(dim=1)  # nlt
+
+        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
+        K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
+
+        if t0 == 0:
+            V0 = None
+            K0 = None
+        else:
+            V0 = self.rec_v[:, :, t0 - 1]
+            K0 = self.rec_k[:, :, t0 - 1]
+
+        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
+        self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
+
+        ######################################################################
+        # compute the readout
+
+        qr = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qr)
+
+        ar = torch.einsum(
+            "nhtd,nltd->nhlt",
+            qr,
+            self.rec_k[:, :, t0:t1],
+        ) / math.sqrt(self.w_qr.size(1))
+
+        ar = ar.softmax(dim=2)  # nhlt
+
+        ar = F.dropout(ar, self.attention_dropout, self.training)
+
+        y = torch.einsum(
+            "nhlt,nltd->nthd",
+            ar,
+            self.rec_v[:, :, t0:t1],
+        ).flatten(2)
+
+        self.cache_y[:, t0:t1] = y @ self.w_o
+
+        return BracketedSequence(self.cache_y, t0, t1 - t0, bs.init_cache)
+
+
+##############################
+
+
+def moving_window(x, dim, win_dim, win_size):
+    size, stride = x.size(), x.stride()
+    size = size[:dim] + (size[dim] - win_size + 1,) + size[dim + 1 :]
+    size = size[:win_dim] + (win_size,) + size[win_dim:]
+    stride = stride[:win_dim] + (stride[dim],) + stride[win_dim:]
+
+    return x.as_strided(size=size, stride=stride)
+
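+# For example (a sketch): for a contiguous x of size (N, H, T, D),
+# moving_window(x, dim=2, win_dim=3, win_size=W) returns a view of size
+# (N, H, T-W+1, W, D) with out[n, h, t, k, d] == x[n, h, t+k, d],
+# without copying any data.
+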
+
+##############################
+
+
+class Caterpillar(nn.Module):
+    def __init__(
+        self,
+        dim_in,
+        dim_qk,
+        dim_v,
+        nb_heads,
+        caterpillar_length,
+        caterpillar_height,
+        attention_dropout=0.0,
+        len_max=1e5,
+    ):
+        super().__init__()
+
+        warnings.warn("Caterpillar", RuntimeWarning)
+
+        def randw(*d):
+            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+
+        self.caterpillar_length = caterpillar_length
+        self.caterpillar_height = caterpillar_height
+        self.attention_dropout = attention_dropout
+
+        self.w_G = randw(nb_heads, caterpillar_height, dim_in)
+        self.b_G = nn.Parameter(
+            torch.full(
+                (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
+            )
+        )
+
+        self.w_K = randw(nb_heads, dim_qk, dim_in)
+        self.w_V = randw(nb_heads, dim_v, dim_in)
+        self.w_Q = randw(nb_heads, dim_qk, dim_in)
+        self.w_O = randw(dim_v * nb_heads, dim_in)
+
+        self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk)
+        self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v)
+
+    def reset_inner_loss(self):
+        self.acc_attention = 0
+        self.acc_nb = 0
+
+    def get_inner_loss(self):
+        # warnings.warn("l2 regularization", RuntimeWarning)
+        # return (self.acc_attention / self.acc_nb).pow(2).sum()
+        return torch.tensor([0], device=self.w_Q.device)
+
+    def forward(self, bs):
+        # Shorthands for the dimensions, to make the code below clearer
+
+        X, t0, t1 = bs.slice(), bs.first, bs.first + bs.nb
+
+        N = bs.x.size(0)
+        T = bs.x.size(1)
+        DV = self.w_V.size(1)
+        DK = self.w_K.size(1)
+        Dout = self.w_O.size(1)
+        CH = self.caterpillar_height
+        CL = self.caterpillar_length
+
+        assert (
+            t0 >= CL and (t1 - t0) % CL == 0
+        ), "bs.first should be at least caterpillar_length, and bs.nb should be a multiple of caterpillar_length"
+
+        if bs.init_cache:
+            self.rec_V = X.new_zeros(N, CH, T, DV)
+            self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :]
+            self.rec_K = X.new_zeros(N, CH, T, DK)
+            self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :]
+            self.cache_Y = X.new_zeros(N, T, Dout)
+
+        ######################################################################
+        # Compute the recurrent state
+
+        G = (
+            torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
+        ).sigmoid()
+
+        V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
+        K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
+
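+        # G[n, h, e, t] in (0, 1) gates how much head h writes into
+        # recurrent row e at time t; 1 - sum_h G is the fraction of the
+        # previous recurrent state that each row retains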
+        A = 1 - G.sum(1)
+        gated_V = torch.einsum("nhet,nhtd->netd", G, V)
+        gated_K = torch.einsum("nhet,nhtd->netd", G, K)
+
+        init_rec_V = self.rec_V[:, :, t0 - CL : t0]
+        init_rec_K = self.rec_K[:, :, t0 - CL : t0]
+
+        A = A.unflatten(2, (-1, CL))
+        gated_V = gated_V.unflatten(2, (-1, CL))
+        gated_K = gated_K.unflatten(2, (-1, CL))
+
+        next_V = pscan_dim(A, gated_V, init_rec_V, dim=2)
+        next_K = pscan_dim(A, gated_K, init_rec_K, dim=2)
+
+        self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3)
+        self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3)
+
+        ######################################################################
+        # compute the readout
+
+        Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
+
+        uv = moving_window(
+            self.rec_V[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
+        )
+
+        uk = moving_window(
+            self.rec_K[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL
+        )
+
+        ar = torch.einsum(
+            "nhtd,nftld->nhtfl",
+            Q,
+            uk,
+        ) / math.sqrt(DK)
+
+        ar = ar.flatten(3).softmax(dim=3).view(ar.size())
+
+        ar = F.dropout(ar, self.attention_dropout, self.training)
+
+        Y = torch.einsum(
+            "nhtfl,nftld->nthd",
+            ar,
+            uv,
+        ).flatten(2)
+
+        self.cache_Y[:, t0:t1] = Y @ self.w_O
+
+        return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
+
+
+##############################
+
+
+class QKVAttention(nn.Module):
+    def __init__(
+        self,
+        dim_in,
+        dim_qk,
+        dim_v,
+        nb_heads=1,
+        causal=False,
+        attention_dropout=0.0,
+    ):
+        super().__init__()
+
+        def randw(*d):
+            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+
+        self.causal = causal
+        self.attention_dropout = attention_dropout
+        self.record_attention = False
+
+        self.w_q = randw(nb_heads, dim_qk, dim_in)
+        self.w_k = randw(nb_heads, dim_qk, dim_in)
+        self.w_v = randw(nb_heads, dim_v, dim_in)
+        self.w_o = randw(dim_v * nb_heads, dim_in)
+
+    def forward(self, bs):
+        x_q = bs.x
+
+        assert (
+            self.causal or bs.complete()
+        ), "Partial evaluation is only possible for causal models"
+
+        if bs.init_cache:
+            self.cache_k = x_q.new_zeros(
+                x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
+            )
+            self.cache_v = x_q.new_zeros(
+                x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
+            )
+            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
+
+        q = torch.einsum("ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_q)
+
+        self.cache_k[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
+            "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_k
+        )
+        self.cache_v[:, :, bs.first : bs.first + bs.nb] = torch.einsum(
+            "ntc,hdc->nhtd", x_q[:, bs.first : bs.first + bs.nb], self.w_v
+        )
+
+        a = torch.einsum(
+            "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs.first + bs.nb]
+        ) / math.sqrt(self.w_q.size(1))
+
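+        # For causal attention, the TxT mask is built once when the
+        # cache is initialized and sliced for the current bracket below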
+        if self.causal:
+            if bs.init_cache:
+                self.cache_attzero = (
+                    torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
+                    < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
+                )
+            a = a.masked_fill(
+                self.cache_attzero[
+                    :, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb
+                ],
+                float("-inf"),
+            )
+
+        a = a.softmax(dim=3)
+
+        if self.record_attention:
+            self.a = a
+
+        a = F.dropout(a, self.attention_dropout, self.training)
+
+        y = torch.einsum(
+            "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs.first + bs.nb]
+        ).flatten(2)
+
+        self.cache_y[:, bs.first : bs.first + bs.nb] = y @ self.w_o
+
+        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.init_cache)
+
+
+##############################
+
+
+class MyGPT(nn.Module):
+    def __init__(
+        self,
+        vocabulary_size,
+        dim_model,
+        dim_keys,
+        dim_hidden,
+        nb_heads,
+        nb_blocks,
+        nb_lines=None,
+        caterpillar_height=None,
+        dim_rec_v=-1,
+        causal=False,
+        dropout=0.0,
+        len_max=1e5,
+        attention_layer="kvrec",
+    ):
+        super().__init__()
+
+        assert attention_layer in {"mha", "dumbrec", "kvrec", "caterpillar"}
+
+        if attention_layer == "caterpillar":
+            assert nb_lines % caterpillar_height == 0
+            self.caterpillar_length = nb_lines // caterpillar_height
+            self.caterpillar_height = caterpillar_height
+        else:
+            self.caterpillar_length = -1
+            self.caterpillar_height = -1
+
+        assert dim_model % nb_heads == 0
+
+        self.embedding = nn.Sequential(
+            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
+            AddPositionalEncoding(len_max),
+        )
+
+        trunk_blocks = []
+
+        def attlayer():
+            if attention_layer == "mha":
+                return QKVAttention(
+                    dim_in=dim_model,
+                    dim_qk=dim_keys,
+                    dim_v=dim_model // nb_heads,
+                    nb_heads=nb_heads,
+                    causal=causal,
+                    attention_dropout=dropout,
+                )
+            elif attention_layer == "dumbrec":
+                return DumbRec(
+                    dim_in=dim_model,
+                    dim_qk=dim_keys,
+                    dim_v=dim_rec_v,
+                    nb_heads=nb_heads,
+                    nb_lines=nb_lines,
+                    attention_dropout=dropout,
+                )
+            elif attention_layer == "kvrec":
+                return KVRec(
+                    dim_in=dim_model,
+                    dim_qk=dim_keys,
+                    dim_v=dim_rec_v,
+                    nb_heads=nb_heads,
+                    nb_lines=nb_lines,
+                    attention_dropout=dropout,
+                )
+            elif attention_layer == "caterpillar":
+                return Caterpillar(
+                    dim_in=dim_model,
+                    dim_qk=dim_keys,
+                    dim_v=dim_rec_v,
+                    nb_heads=nb_heads,
+                    caterpillar_length=self.caterpillar_length,
+                    caterpillar_height=self.caterpillar_height,
+                    attention_dropout=dropout,
+                )
+            else:
+                raise ValueError(f"Unknown attention type {attention_layer}.")
+
+        for b in range(nb_blocks):
+            trunk_blocks += [
+                WithResidual(
+                    CacheWrapper(nn.LayerNorm((dim_model,))),
+                    attlayer(),
+                ),
+                WithResidual(
+                    CacheWrapper(
+                        nn.LayerNorm((dim_model,)),
+                        nn.Linear(in_features=dim_model, out_features=dim_hidden),
+                        nn.ReLU(),
+                        nn.Linear(in_features=dim_hidden, out_features=dim_model),
+                        nn.Dropout(dropout),
+                    ),
+                ),
+            ]
+
+        self.trunk = nn.Sequential(*trunk_blocks)
+
+        self.readout = CacheWrapper(
+            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+        )
+
+        with torch.no_grad():
+            for m in self.modules():
+                if isinstance(m, nn.Embedding):
+                    m.weight.normal_(mean=0, std=2e-2)
+                elif isinstance(m, nn.LayerNorm):
+                    m.bias.zero_()
+                    m.weight.fill_(1.0)
+
+        self.reset_inner_loss()
+
+    def forward(self, bs):
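+        # Shift the input right by one token, so the readout at position
+        # t is trained to predict the token originally at position t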
+        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.init_cache)
+
+        # To make the code simpler in the Caterpillar layer, we pad
+        # here. It's unclear if/how much it hurts computationally by
+        # increasing the sequence length for the other layers
+
+        if self.caterpillar_length > 0:
+            original_nb = bs.nb
+            if bs.nb % self.caterpillar_length > 0:
+                bs.nb += self.caterpillar_length - bs.nb % self.caterpillar_length
+
+            bs = BracketedSequence(
+                F.pad(bs.x, (self.caterpillar_length, self.caterpillar_length)),
+                bs.first + self.caterpillar_length,
+                bs.nb,
+                bs.init_cache,
+            )
+
+        bs = self.embedding(bs)
+        bs = self.trunk(bs)
+        bs = self.readout(bs)
+
+        if self.caterpillar_length > 0:
+            bs = BracketedSequence(
+                F.pad(bs.x, (0, 0, -self.caterpillar_length, -self.caterpillar_length)),
+                bs.first - self.caterpillar_length,
+                original_nb,
+                bs.init_cache,
+            )
+
+        return bs
+
+    # ar_mask is a tensor with 0s and 1s, of the same shape as input,
+    # with 1s where tokens should be generated. The others are kept
+    # unchanged.
+
+    def masked_inplace_autoregression(
+        self,
+        input_src,
+        ar_mask_src,
+        forbidden_tokens=None,
+        deterministic_synthesis=False,
+    ):
+        input = input_src.to(self.readout.f.weight.device)
+        ar_mask = ar_mask_src.to(self.readout.f.weight.device)
+        to_generate = (ar_mask.sum(0) > 0).nonzero()
+        if to_generate.min() > 0:
+            self(
+                BracketedSequence(input, 0, to_generate.min(), True)
+            )  # Needed to initialize the model's cache
+        for s in range(to_generate.min(), to_generate.max() + 1):
+            output = self(BracketedSequence(input, s, 1, s == 0)).x
+            logits = output[:, s]
+            if forbidden_tokens is not None:
+                logits = logits.masked_fill(forbidden_tokens, float("-inf"))
+            if deterministic_synthesis:
+                t_next = logits.argmax(1)
+            else:
+                dist = torch.distributions.categorical.Categorical(logits=logits)
+                t_next = dist.sample()
+            input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
+
+        input_src.copy_(input)
+
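+    # A usage sketch (hypothetical values, not from the original code):
+    # with input = [[3, 7, 0, 0]] and ar_mask = [[0, 0, 1, 1]], the first
+    # two tokens act as the prompt and the last two are sampled in place,
+    # one position at a time.
+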
+    def reset_inner_loss(self):
+        for m in self.modules():
+            if m is not self and hasattr(m, "reset_inner_loss"):
+                m.reset_inner_loss()
+
+    def get_inner_loss(self):
+        l = torch.tensor([0.0], device=self.readout.f.weight.device)
+        for m in self.modules():
+            if m is not self and hasattr(m, "get_inner_loss"):
+                l += m.get_inner_loss()
+        return l
+
+    def record_attention(self, v=True):
+        for m in self.modules():
+            if isinstance(m, QKVAttention):
+                m.record_attention = v
+
+    def retrieve_attention(self):
+        a = []
+        for m in self.modules():
+            if isinstance(m, QKVAttention):
+                a.append(m.a)
+        return a
+
+
+######################################################################
+
+if __name__ == "__main__":
+    print("Basic check.")
+
+    m = Caterpillar(
+        dim_in=4,
+        dim_qk=3,
+        dim_v=7,
+        nb_heads=1,
+        caterpillar_length=7,
+        caterpillar_height=3,
+        attention_dropout=0.0,
+    )
+
+    m.reset_inner_loss()
+    x = torch.randn(1, 21 + 2 * 7, 4)
+    y1 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
+    y2 = m(BracketedSequence(x, first=7, nb=21, init_cache=True)).x[:, 7:28]
+    y3a = m(BracketedSequence(x, first=7, nb=14, init_cache=True)).x[:, 7:21]
+    y3b = m(BracketedSequence(x, first=21, nb=7, init_cache=False)).x[:, 21:28]
+    print((y1 - y2).abs().max())
+    print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max())
+    exit(0)
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    vocabulary_size = 128
+    x = torch.randint(vocabulary_size, (6, 1024))
+
+    model = MyGPT(
+        vocabulary_size=vocabulary_size,
+        dim_model=512,
+        dim_keys=64,
+        dim_hidden=2048,
+        nb_heads=8,
+        nb_lines=128,
+        nb_blocks=12,
+        dropout=0.1,
+        causal=True,
+    )
+
+    x = x.to(device)
+    model.to(device)
+
+    import time, sys
+
+    # import torchvision.models as models
+    # from torch.profiler import profile, record_function, ProfilerActivity
+
+    # with profile(activities=[ProfilerActivity.CPU,  ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
+    # with record_function("model_inference"):
+
+    model.eval()
+    for i in range(3):
+        start_time = time.perf_counter()
+        for k in range(10):
+            model(BracketedSequence(x))
+        duration = time.perf_counter() - start_time
+        print(duration)
+        sys.stdout.flush()
+
+    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
+    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+
+    # print("##############################################################")
+    # y2 = torch.randn_like(y1)
+    # for s in range(x.size(1)):
+    # z = model(BracketedSequence(x, s, 1))
+    # y2[:, s : s + 1] = z.slice()
+
+    # print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
+
+######################################################################
diff --git a/picoclvr.py b/picoclvr.py
new file mode 100755 (executable)
index 0000000..0cd3062
--- /dev/null
+++ b/picoclvr.py
@@ -0,0 +1,370 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math
+import torch, torchvision
+import torch.nn.functional as F
+
+color_name2rgb = {
+    "white": [255, 255, 255],
+    "red": [255, 0, 0],
+    "green": [0, 128, 0],
+    "blue": [0, 0, 255],
+    "yellow": [255, 255, 0],
+    "black": [0, 0, 0],
+    "maroon": [128, 0, 0],
+    "dark_red": [139, 0, 0],
+    "brown": [165, 42, 42],
+    "firebrick": [178, 34, 34],
+    "crimson": [220, 20, 60],
+    "tomato": [255, 99, 71],
+    "coral": [255, 127, 80],
+    "indian_red": [205, 92, 92],
+    "light_coral": [240, 128, 128],
+    "dark_salmon": [233, 150, 122],
+    "salmon": [250, 128, 114],
+    "light_salmon": [255, 160, 122],
+    "orange_red": [255, 69, 0],
+    "dark_orange": [255, 140, 0],
+    "orange": [255, 165, 0],
+    "gold": [255, 215, 0],
+    "dark_golden_rod": [184, 134, 11],
+    "golden_rod": [218, 165, 32],
+    "pale_golden_rod": [238, 232, 170],
+    "dark_khaki": [189, 183, 107],
+    "khaki": [240, 230, 140],
+    "olive": [128, 128, 0],
+    "yellow_green": [154, 205, 50],
+    "dark_olive_green": [85, 107, 47],
+    "olive_drab": [107, 142, 35],
+    "lawn_green": [124, 252, 0],
+    "chartreuse": [127, 255, 0],
+    "green_yellow": [173, 255, 47],
+    "dark_green": [0, 100, 0],
+    "forest_green": [34, 139, 34],
+    "lime": [0, 255, 0],
+    "lime_green": [50, 205, 50],
+    "light_green": [144, 238, 144],
+    "pale_green": [152, 251, 152],
+    "dark_sea_green": [143, 188, 143],
+    "medium_spring_green": [0, 250, 154],
+    "spring_green": [0, 255, 127],
+    "sea_green": [46, 139, 87],
+    "medium_aqua_marine": [102, 205, 170],
+    "medium_sea_green": [60, 179, 113],
+    "light_sea_green": [32, 178, 170],
+    "dark_slate_gray": [47, 79, 79],
+    "teal": [0, 128, 128],
+    "dark_cyan": [0, 139, 139],
+    "aqua": [0, 255, 255],
+    "cyan": [0, 255, 255],
+    "light_cyan": [224, 255, 255],
+    "dark_turquoise": [0, 206, 209],
+    "turquoise": [64, 224, 208],
+    "medium_turquoise": [72, 209, 204],
+    "pale_turquoise": [175, 238, 238],
+    "aqua_marine": [127, 255, 212],
+    "powder_blue": [176, 224, 230],
+    "cadet_blue": [95, 158, 160],
+    "steel_blue": [70, 130, 180],
+    "corn_flower_blue": [100, 149, 237],
+    "deep_sky_blue": [0, 191, 255],
+    "dodger_blue": [30, 144, 255],
+    "light_blue": [173, 216, 230],
+    "sky_blue": [135, 206, 235],
+    "light_sky_blue": [135, 206, 250],
+    "midnight_blue": [25, 25, 112],
+    "navy": [0, 0, 128],
+    "dark_blue": [0, 0, 139],
+    "medium_blue": [0, 0, 205],
+    "royal_blue": [65, 105, 225],
+    "blue_violet": [138, 43, 226],
+    "indigo": [75, 0, 130],
+    "dark_slate_blue": [72, 61, 139],
+    "slate_blue": [106, 90, 205],
+    "medium_slate_blue": [123, 104, 238],
+    "medium_purple": [147, 112, 219],
+    "dark_magenta": [139, 0, 139],
+    "dark_violet": [148, 0, 211],
+    "dark_orchid": [153, 50, 204],
+    "medium_orchid": [186, 85, 211],
+    "purple": [128, 0, 128],
+    "thistle": [216, 191, 216],
+    "plum": [221, 160, 221],
+    "violet": [238, 130, 238],
+    "magenta": [255, 0, 255],
+    "orchid": [218, 112, 214],
+    "medium_violet_red": [199, 21, 133],
+    "pale_violet_red": [219, 112, 147],
+    "deep_pink": [255, 20, 147],
+    "hot_pink": [255, 105, 180],
+    "light_pink": [255, 182, 193],
+    "pink": [255, 192, 203],
+    "antique_white": [250, 235, 215],
+    "beige": [245, 245, 220],
+    "bisque": [255, 228, 196],
+    "blanched_almond": [255, 235, 205],
+    "wheat": [245, 222, 179],
+    "corn_silk": [255, 248, 220],
+    "lemon_chiffon": [255, 250, 205],
+    "light_golden_rod_yellow": [250, 250, 210],
+    "light_yellow": [255, 255, 224],
+    "saddle_brown": [139, 69, 19],
+    "sienna": [160, 82, 45],
+    "chocolate": [210, 105, 30],
+    "peru": [205, 133, 63],
+    "sandy_brown": [244, 164, 96],
+    "burly_wood": [222, 184, 135],
+    "tan": [210, 180, 140],
+    "rosy_brown": [188, 143, 143],
+    "moccasin": [255, 228, 181],
+    "navajo_white": [255, 222, 173],
+    "peach_puff": [255, 218, 185],
+    "misty_rose": [255, 228, 225],
+    "lavender_blush": [255, 240, 245],
+    "linen": [250, 240, 230],
+    "old_lace": [253, 245, 230],
+    "papaya_whip": [255, 239, 213],
+    "sea_shell": [255, 245, 238],
+    "mint_cream": [245, 255, 250],
+    "slate_gray": [112, 128, 144],
+    "light_slate_gray": [119, 136, 153],
+    "light_steel_blue": [176, 196, 222],
+    "lavender": [230, 230, 250],
+    "floral_white": [255, 250, 240],
+    "alice_blue": [240, 248, 255],
+    "ghost_white": [248, 248, 255],
+    "honeydew": [240, 255, 240],
+    "ivory": [255, 255, 240],
+    "azure": [240, 255, 255],
+    "snow": [255, 250, 250],
+    "silver": [192, 192, 192],
+    "gainsboro": [220, 220, 220],
+    "white_smoke": [245, 245, 245],
+}
+
+color_name2id = dict([(n, k) for k, n in enumerate(color_name2rgb.keys())])
+color_id2name = dict([(k, n) for k, n in enumerate(color_name2rgb.keys())])
+
+######################################################################
+
+
+def all_properties(height, width, nb_squares, square_i, square_j, square_c):
+    s = []
+
+    for r, c_r in [(k, color_id2name[square_c[k].item()]) for k in range(nb_squares)]:
+        s += [f"there is {c_r}"]
+
+        if square_i[r] >= height - height // 3:
+            s += [f"{c_r} bottom"]
+        if square_i[r] < height // 3:
+            s += [f"{c_r} top"]
+        if square_j[r] >= width - width // 3:
+            s += [f"{c_r} right"]
+        if square_j[r] < width // 3:
+            s += [f"{c_r} left"]
+
+        for t, c_t in [
+            (k, color_id2name[square_c[k].item()]) for k in range(nb_squares)
+        ]:
+            if square_i[r] > square_i[t]:
+                s += [f"{c_r} below {c_t}"]
+            if square_i[r] < square_i[t]:
+                s += [f"{c_r} above {c_t}"]
+            if square_j[r] > square_j[t]:
+                s += [f"{c_r} right of {c_t}"]
+            if square_j[r] < square_j[t]:
+                s += [f"{c_r} left of {c_t}"]
+
+    return s
+
+
+######################################################################
+
+# Generates sequences
+
+
+def generate(
+    nb,
+    height,
+    width,
+    max_nb_squares=5,
+    max_nb_properties=10,
+    nb_colors=5,
+    pruner=None,
+):
+    assert nb_colors >= max_nb_squares and nb_colors <= len(color_name2rgb) - 1
+
+    descr = []
+
+    for n in range(nb):
+        # We want the number of squares to be uniform over the possible
+        # image contents; since there are ~nb_colors^k contents with k
+        # squares, sample k with probability proportional to nb_colors^k
+        logits = math.log(nb_colors) * torch.arange(1, max_nb_squares + 1).float()
+        dist = torch.distributions.categorical.Categorical(logits=logits)
+        nb_squares = dist.sample((1,)) + 1
+        # nb_squares = torch.randint(max_nb_squares, (1,)) + 1
+        square_position = torch.randperm(height * width)[:nb_squares]
+
+        # color 0 is white and reserved for the background
+        square_c = torch.randperm(nb_colors)[:nb_squares] + 1
+        square_i = square_position.div(width, rounding_mode="floor")
+        square_j = square_position % width
+
+        img = torch.zeros(height * width, dtype=torch.int64)
+        for k in range(nb_squares):
+            img[square_position[k]] = square_c[k]
+
+        # generates all the true properties
+
+        s = all_properties(height, width, nb_squares, square_i, square_j, square_c)
+
+        if pruner is not None:
+            s = list(filter(pruner, s))
+
+        # picks at most max_nb_properties at random
+
+        nb_properties = torch.randint(max_nb_properties, (1,)) + 1
+        s = (
+            " <sep> ".join([s[k] for k in torch.randperm(len(s))[:nb_properties]])
+            + " <img> "
+            + " ".join([f"{color_id2name[n.item()]}" for n in img])
+        )
+
+        descr += [s]
+
+    return descr
+
+
+######################################################################
+
+# Extracts the images after <img> in descr as a Nx3xHxW tensor
+
+
+def descr2img(descr, height, width):
+    result = []
+
+    def token2color(t):
+        try:
+            return color_name2rgb[t]
+        except KeyError:
+            return [128, 128, 128]
+
+    for d in descr:
+        d = d.split("<img>")[1]
+        d = d.strip().split(" ")[: height * width]
+        d = d + ["<unk>"] * (height * width - len(d))
+        d = [token2color(t) for t in d]
+        img = torch.tensor(d).permute(1, 0).reshape(1, 3, height, width)
+        result.append(img)
+
+    return torch.cat(result, 0)
+
+
+######################################################################
+
+# Returns all the properties of the image after <img> in descr
+
+
+def descr2properties(descr, height, width):
+    if type(descr) == list:
+        return [descr2properties(d, height, width) for d in descr]
+
+    d = descr.split("<img>")
+    img_tokens = d[-1] if len(d) > 1 else ""
+    img_tokens = img_tokens.strip().split(" ")[: height * width]
+    if len(img_tokens) != height * width:
+        return []
+
+    seen = {}
+    for k, x in enumerate(img_tokens):
+        if x != color_id2name[0]:
+            if x in color_name2rgb:
+                if x in seen:
+                    return []
+            else:
+                return []
+            seen[x] = (color_name2id[x], k // width, k % width)
+
+    square_infos = tuple(zip(*seen.values()))
+
+    if square_infos:
+        square_c = torch.tensor(square_infos[0])
+        square_i = torch.tensor(square_infos[1])
+        square_j = torch.tensor(square_infos[2])
+    else:
+        square_c = torch.tensor([])
+        square_i = torch.tensor([])
+        square_j = torch.tensor([])
+
+    s = all_properties(height, width, len(seen), square_i, square_j, square_c)
+
+    return s
+
+
+######################################################################
+
+# Returns a triplet composed of (1) the total number of properties
+# before <img> in descr, (2) the total number of properties the image
+# after <img> verifies, and (3) the number of properties in (1) not in
+# (2)
+
+
+def nb_properties(descr, height, width, pruner=None):
+    if type(descr) == list:
+        return [nb_properties(d, height, width, pruner) for d in descr]
+
+    d = descr.split("<img>", 1)
+    if len(d) < 2:
+        return (0, 0, 0)
+    d = d[0].strip().split("<sep>")
+    d = [x.strip() for x in d]
+
+    verified_properties = set(descr2properties(descr, height, width))
+
+    if pruner is None:
+        requested_properties = set(d)
+    else:
+        requested_properties = set(filter(pruner, d))
+
+    missing_properties = requested_properties - verified_properties
+
+    return (len(requested_properties), len(verified_properties), len(missing_properties))
+
+
+######################################################################
+
+if __name__ == "__main__":
+    for n in range(16):
+        descr = generate(nb=1, height=12, width=16)
+
+        print(nb_properties(descr, height=12, width=16))
+
+        with open(f"picoclvr_example_{n:02d}.txt", "w") as f:
+            for d in descr:
+                f.write(f"{d}\n\n")
+
+        img = descr2img(descr, height=12, width=16)
+        if img.size(0) == 1:
+            img = F.pad(img, (1, 1, 1, 1), value=64)
+
+        torchvision.utils.save_image(
+            img / 255.0,
+            f"picoclvr_example_{n:02d}.png",
+            padding=1,
+            nrow=4,
+            pad_value=0.8,
+        )
+
+    import time
+
+    start_time = time.perf_counter()
+    descr = generate(nb=1000, height=12, width=16)
+    end_time = time.perf_counter()
+    print(f"{len(descr) / (end_time - start_time):.02f} samples per second")
+
+######################################################################
diff --git a/problems.py b/problems.py
new file mode 100755 (executable)
index 0000000..9e368c2
--- /dev/null
+++ b/problems.py
@@ -0,0 +1,490 @@
+#!/usr/bin/env python
+
+import math
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+
+class Problem:
+    def generate_sequences(self, nb):
+        pass
+
+    def seq2str(self, seq):
+        return "[NOT IMPLEMENTED]"
+
+    def compute_nb_correct(self, input, ar_mask, result):
+        nb_total = ar_mask.sum().item()
+        nb_correct = ((result == input).long() * ar_mask).sum().item()
+        return nb_total, nb_correct
+
+
+####################
+
+
+class ProblemDegradation(Problem):
+    def __init__(self, nb_state_tokens=5, nb_time_steps=12, value_max=25, hard=False):
+        assert value_max // nb_state_tokens >= 2
+        self.nb_state_tokens = nb_state_tokens
+        self.nb_time_steps = nb_time_steps
+        self.value_max = value_max
+        self.hard = hard
+
+    def generate_sequences(self, nb):
+        x = (
+            torch.rand(nb, self.nb_state_tokens).sort(dim=-1).indices == 0
+        ).long() * self.value_max
+        seq = [x]
+
+        for t in range(self.nb_time_steps - 1):
+            v = (torch.rand(x.size()).sort(dim=-1).indices + 1) * (x >= 2).long()
+            u = (v.max(dim=-1, keepdim=True).values == v).long()
+            n = (
+                (u * x)
+                .minimum(2 + torch.randint(self.value_max // 4 - 2, x.size()))
+                .sum(dim=-1, keepdim=True)
+            )
+            m = 1 + ((n - 1) * torch.rand(n.size())).long()
+            x = (
+                x
+                + m * u.roll(shifts=-1, dims=-1)
+                - n * u
+                + (n - m) * u.roll(shifts=1, dims=-1)
+            )
+            seq.append(x)
+
+        if self.hard:
+            seq.reverse()
+
+        seq = torch.cat(seq, dim=1)
+        return seq, seq.new_full(seq.size(), 1, dtype=torch.int64)
+
+    def compute_nb_correct(self, input, ar_mask, result):
+        nb_total = result.size(0)
+        nb_correct = 0
+        e = result.new_zeros(self.nb_state_tokens)
+
+        for seq in result:
+            states = list(seq.split(self.nb_state_tokens))
+            if self.hard:
+                states.reverse()
+
+            d = states[0]
+            j = d.sort(descending=True).indices[0]
+            e.zero_()
+            e[j] = self.value_max
+            if (d - e).abs().sum() == 0:
+                nb_errors = 0
+                for k in range(len(states) - 1):
+                    d = states[k + 1] - states[k]
+                    j = d.sort(descending=False).indices[0]
+                    if (
+                        d[j] == 0
+                        or d[j] > self.value_max // 4
+                        or d[(j + 1) % e.size(0)] <= 0
+                        or d[(j + 1) % e.size(0)] >= -d[j]
+                    ):
+                        nb_errors += 1
+                    else:
+                        e.zero_()
+                        e[j] = d[j]
+                        e[(j + 1) % e.size(0)] = d[(j + 1) % e.size(0)]
+                        e[(j - 1) % e.size(0)] = -d[(j + 1) % e.size(0)] - d[j]
+                        if (d - e).abs().sum() > 0:
+                            nb_errors += 1
+                if nb_errors == 0:
+                    nb_correct += 1
+
+        return nb_total, nb_correct
+
+    def seq2str(self, seq):
+        return " | ".join(
+            [" ".join([f"{x:02d}" for x in s]) for s in seq.split(self.nb_state_tokens)]
+        )
+
+
+####################
+
+
+class ProblemMemory(Problem):
+    def __init__(self, len_total=32):
+        self.len_total = len_total
+        self.max_len_pattern = 5
+        self.nb_noise_tokens = 10
+        self.start_pattern_token = 0
+        self.end_pattern_token = 1
+        self.start_result_token = 2
+        self.end_result_token = 3
+        self.token_string = "[]<>" + "".join(
+            [chr(ord("a") + k) for k in range(self.nb_noise_tokens)]
+        )
+
+    def generate_sequences(self, nb):
+        sequences = (
+            torch.randint(self.nb_noise_tokens, (nb, self.len_total))
+            + self.end_result_token
+            + 1
+        )
+        len_patterns = torch.randint(self.max_len_pattern, (nb,)) + 1
+        pattern_positions = torch.randint(
+            self.len_total - (5 + 2 * self.max_len_pattern), (nb,)
+        )
+        k = self.len_total - (3 + self.max_len_pattern)
+        for i in range(nb):
+            l = len_patterns[i]
+            j = pattern_positions[i]
+            sequences[i, j] = self.start_pattern_token
+            sequences[i, j + l + 2] = self.end_pattern_token
+            sequences[i, k] = self.start_result_token
+            sequences[i, k + l + 2] = self.end_result_token
+            sequences[i, k + 1 : k + 2 + l] = sequences[i, j + 1 : j + 2 + l]
+
+        j = torch.arange(self.len_total)[None, :]
+        ar_mask = (j > k).long() * (j <= k + 1 + len_patterns[:, None]).long()
+
+        return sequences, ar_mask
+
+    def seq2str(self, seq):
+        return "".join(self.token_string[x.item()] for x in seq)
+
+
+class ProblemTwoTargets(Problem):
+    def __init__(self, len_total=10, len_targets=3):
+        assert len_targets >= 3
+        assert len_total >= 3 * len_targets - 1
+        self.len_total = len_total
+        self.len_targets = len_targets
+
+    def generate_sequences(self, nb):
+        k = torch.arange(self.len_total)[None, :]
+        s = torch.randint(10, (nb, self.len_total))
+        l = torch.rand(nb, self.len_total)
+        l = l * (k <= self.len_total - self.len_targets).long()
+        k1 = l.argmax(dim=1, keepdim=True)
+        m = (k != k1).long() * (k != k1 + self.len_targets - 1).long()
+        s = s * m + 10 * (1 - m)
+        l = l * (
+            1
+            - (k + self.len_targets - 1 >= k1).long()
+            * (k < k1 + self.len_targets).long()
+        )
+        k2 = l.argmax(dim=1, keepdim=True)
+        m = (k != k2).long() * (k != k2 + self.len_targets - 1).long()
+        s = s * m + 11 * (1 - m)
+        a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
+        a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
+        sequences = torch.cat(
+            (
+                s,
+                torch.full((nb, 1), 12),
+                a1,
+                torch.full((nb, 1), 12),
+                a2,
+                torch.full((nb, 1), 12),
+            ),
+            1,
+        )
+        ar_mask = (sequences == 12).long()
+        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
+        return sequences, ar_mask
+
+    def seq2str(self, seq):
+        return "".join("0123456789-+|"[x.item()] for x in seq)
+
+
+####################
+
+
+class ProblemByHeart(Problem):
+    def __init__(self, nb_sentences=100, len_prompt=8, len_result=8):
+        self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
+        self.seq[:, len_prompt] = 10
+
+    def generate_sequences(self, nb):
+        sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
+        ar_mask = (sequences == 10).long()
+        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
+        return sequences, ar_mask
+
+    def seq2str(self, seq):
+        return "".join("0123456789|"[x.item()] for x in seq)
+
+
+####################
+
+
+class ProblemLearnOperator(Problem):
+    def __init__(self, nb_operators=100, len_source=6, len_result=9):
+        self.len_source = len_source
+        self.len_result = len_result
+        self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
+        self.operators = F.one_hot(
+            torch.rand(nb_operators, len_result, len_source).argmax(-1),
+            num_classes=len_source,
+        )
+
+    def generate_sequences(self, nb):
+        nb_operators = torch.randint(self.operators.size(0), (nb,))
+        operators = self.operators[nb_operators]
+        nb_operators = (
+            nb_operators[:, None]
+            // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
+        ) % 10
+        marker1 = torch.full((nb, 1), 10)
+        source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
+        marker2 = torch.full((nb, 1), 11)
+        result = operators.bmm(source[:, :, None]).squeeze(-1)
+        sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
+        ar_mask = (sequences == 11).long()
+        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
+        return sequences, ar_mask
+
+    def seq2str(self, seq):
+        return "".join("0123456789|>"[x.item()] for x in seq)
+
+
+####################
+
+
+class ProblemGuessOperator(Problem):
+    def __init__(self, len_source=5, len_result=8):
+        self.len_source = len_source
+        self.len_result = len_result
+
+    def generate_sequences(self, nb):
+        operators = F.one_hot(
+            torch.rand(nb, self.len_result, self.len_source).argmax(-1),
+            num_classes=self.len_source,
+        )
+        source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
+        marker1 = torch.full((nb, 1), 10)
+        result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
+        marker2 = torch.full((nb, 1), 11)
+        source2 = torch.randint(10, (nb, self.len_source))
+        marker3 = torch.full((nb, 1), 12)
+        result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
+
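+        # Only result2 is masked for prediction: the model has to infer
+        # the operator from the (source1, result1) pair.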
+        sequences = torch.cat(
+            (source1, marker1, result1, marker2, source2, marker3, result2), 1
+        )
+        ar_mask = (sequences == 12).long()
+        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
+        return sequences, ar_mask
+
+    def seq2str(self, seq):
+        return "".join("0123456789>|~"[x.item()] for x in seq)
+
+
+####################
+
+
+class ProblemAddition(Problem):
+    def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
+        self.nb_digits = nb_digits
+        self.zero_padded = zero_padded
+        self.inverted_result = inverted_result
+        self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")])
+        self.id2char = dict([(n, c) for c, n in self.char2id.items()])
+
+    def tensorize(self, strings):
+        len_max = max([len(x) for x in strings])
+        return torch.tensor(
+            [
+                [self.char2id[c] for c in s + "$" * (len_max - len(s))]
+                for s in strings
+            ]
+        )
+
+    def generate_sequences(self, nb):
+        sequences = []
+        for k in range(nb):
+            a, b = torch.randint(10**self.nb_digits, (2,))
+            c = a + b
+            a, b, c = str(a.item()), str(b.item()), str(c.item())
+            if self.zero_padded:
+                a = "0" * (self.nb_digits - len(a)) + a
+                b = "0" * (self.nb_digits - len(b)) + b
+                c = "0" * (self.nb_digits + 1 - len(c)) + c
+            if self.inverted_result:
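+                # least-significant digit first, the order in which
+                # the carries propagate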
+                c = c[::-1]
+            sequences.append(f"{a}+{b}={c}$")
+
+        sequences = self.tensorize(sequences)
+        ar_mask = (sequences == self.char2id["="]).long()
+        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
+        return sequences, ar_mask
+
+    def seq2str(self, seq):
+        return "".join(self.id2char[x.item()] for x in seq)
+
+
+####################
+
+
+class ProblemMixing(Problem):
+    def __init__(
+        self, height=4, width=4, nb_time_steps=9, hard=False, random_start=True
+    ):
+        self.height = height
+        self.width = width
+        self.nb_time_steps = nb_time_steps
+        self.hard = hard
+        self.random_start = random_start
+
+    def start_random(self, nb):
+        y = torch.arange(self.height * self.width).reshape(1, -1).expand(nb, -1)
+
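+        # With random_start, the cells of one random row and one random
+        # column are masked out (set to height * width, rendered "**").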
+        if self.random_start:
+            i = (
+                torch.arange(self.height)
+                .reshape(1, -1, 1)
+                .expand(nb, self.height, self.width)
+            )
+            j = (
+                torch.arange(self.width)
+                .reshape(1, 1, -1)
+                .expand(nb, self.height, self.width)
+            )
+
+            ri = torch.randint(self.height, (nb,)).reshape(nb, 1, 1)
+            rj = torch.randint(self.width, (nb,)).reshape(nb, 1, 1)
+
+            m = 1 - torch.logical_or(i == ri, j == rj).long().flatten(1)
+
+            y = y * m + self.height * self.width * (1 - m)
+
+        y = y.reshape(nb, self.height, self.width)
+
+        return y
+
+    def start_error(self, x):
+        if self.random_start:
+            i = (
+                torch.arange(self.height, device=x.device)
+                .reshape(1, -1, 1)
+                .expand_as(x)
+            )
+            j = torch.arange(self.width, device=x.device).reshape(1, 1, -1).expand_as(x)
+
+            ri = (
+                (x == self.height * self.width)
+                .long()
+                .sum(dim=-1)
+                .argmax(-1)
+                .view(-1, 1, 1)
+            )
+            rj = (
+                (x == self.height * self.width)
+                .long()
+                .sum(dim=-2)
+                .argmax(-1)
+                .view(-1, 1, 1)
+            )
+
+            m = 1 - torch.logical_or(i == ri, j == rj).long().flatten(1)
+        else:
+            m = 1
+
+        x = x.flatten(1)
+        u = torch.arange(self.height * self.width, device=x.device).reshape(1, -1)
+
+        d = (x - (m * u + (1 - m) * self.height * self.width)).abs().sum(-1)
+
+        return d
+
+    def moves(self, x):
+        y = (
+            x[:, None, :, :]
+            .expand(-1, self.height * 2 + self.width * 2, -1, -1)
+            .clone()
+        )
+        k = 0
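+        # build the 2 * (height + width) candidate next states: each
+        # one rolls a single row or a single column by one position, in
+        # one of the two directions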
+
+        for i in range(self.height):
+            y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=-1)
+            k += 1
+            y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=1)
+            k += 1
+
+        for j in range(self.width):
+            y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=-1)
+            k += 1
+            y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=1)
+            k += 1
+
+        return y
+
+    def generate_sequences(self, nb):
+        x = self.start_random(nb)
+
+        seq = [x.flatten(1)]
+
+        for t in range(self.nb_time_steps - 1):
+            y = self.moves(x)
+            x = y[torch.arange(nb), torch.randint(y.size(1), (nb,))]
+            seq.append(x.flatten(1))
+
+        if self.hard:
+            seq.reverse()
+
+        seq = torch.cat(seq, dim=1)
+        return seq, seq.new_full(seq.size(), 1, dtype=torch.int64)
+
+    def compute_nb_correct(self, input, ar_mask, result):
+        a = [
+            x.reshape(result.size(0), self.height, self.width)
+            for x in result.split(self.height * self.width, dim=1)
+        ]
+        if self.hard:
+            a.reverse()
+
+        x = a[0]
+
+        d = self.start_error(x)
+
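+        # d is the start-state error plus, for each time step, the
+        # distance between the state and the closest state reachable in
+        # one move from the previous one; a sequence is correct iff d == 0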
+        for t in range(self.nb_time_steps - 1):
+            x0, x = a[t], a[t + 1]
+            y = self.moves(x0)
+            d = d + (x[:, None] - y).abs().sum((-1, -2)).min(dim=-1).values
+
+        nb_total, nb_correct = result.size(0), (d == 0).long().sum().item()
+
+        return nb_total, nb_correct
+
+    def seq2str(self, seq):
+        return " | ".join(
+            [
+                " ".join(
+                    [
+                        "-".join(
+                            [
+                                f"{x:02d}" if x < self.height * self.width else "**"
+                                for x in s
+                            ]
+                        )
+                        for s in r.split(self.width)
+                    ]
+                )
+                for r in seq.split(self.height * self.width)
+            ]
+        )
+
+
+####################
+
+if __name__ == "__main__":
+    p = ProblemMixing(height=3, width=3, random_start=False)
+
+    s, m = p.generate_sequences(10000)
+    for x in s[:5]:
+        print(p.seq2str(x))
+    print(p.compute_nb_correct(None, None, s))
diff --git a/pscan.py b/pscan.py
new file mode 100755 (executable)
index 0000000..0ec7b13
--- /dev/null
+++ b/pscan.py
@@ -0,0 +1,139 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import torch
+
+######################################################################
+
+
+class PScan(torch.autograd.Function):
+    # Given A of shape NxTx1 and X of shape NxTxD, expands A and X in
+    # place with O(T) operations (O(log(T)) parallel steps if not
+    # core-bounded), so that the recurrence
+    #
+    # Y[:, 0] = A[:, 0] * Y_init   + X[:, 0]
+    # Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t]
+    #
+    # can then be computed as
+    #
+    # Y[:, t] = A[:, t] * Y_init + X[:, t]
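+    #
+    # For instance with T = 2, expand_ turns
+    #
+    #   A = (A0, A1),  X = (X0, X1)
+    #
+    # into
+    #
+    #   A = (A0, A0*A1),  X = (X0, A1*X0 + X1)
+    #
+    # and indeed (A0*A1) * Y_init + (A1*X0 + X1) = A1 * Y0 + X1 = Y1.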
+
+    @staticmethod
+    def expand_(A, X):
+        if A.size(1) == 1:
+            return
+        T = 2 * (A.size(1) // 2)
+        Aa = A[:, :T].view(A.size(0), T // 2, 2, -1, 1)
+        Xa = X[:, :T].view(X.size(0), T // 2, 2, -1, X.size(-1))
+        Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
+        Aa[:, :, 1].mul_(Aa[:, :, 0])
+        PScan.expand_(Aa[:, :, 1], Xa[:, :, 1])
+        Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1]))
+        Aa[:, 1:, 0].mul_(Aa[:, :-1, 1])
+        if T < A.size(1):
+            X[:, -1].add_(A[:, -1].mul(X[:, -2]))
+            A[:, -1].mul_(A[:, -2])
+
+    @staticmethod
+    def acc_rev_(A, X):
+        if X.size(1) == 1:
+            return
+        T = 2 * (X.size(1) // 2)
+        Aa = A[:, -T:].view(A.size(0), T // 2, 2, -1, 1)
+        Xa = X[:, -T:].view(X.size(0), T // 2, 2, -1, X.size(-1))
+        Xa[:, :, 0].add_(Aa[:, :, 1].mul(Xa[:, :, 1]))
+        B = Aa[:, :, 0].clone()
+        B[:, 1:].mul_(Aa[:, :-1, 1])
+        PScan.acc_rev_(B, Xa[:, :, 0])
+        Xa[:, :-1, 1].add_(Aa[:, 1:, 0].mul(Xa[:, 1:, 0]))
+        if T < A.size(1):
+            X[:, 0].add_(A[:, 1].mul(X[:, 1]))
+
+    # A is NxT, X is NxTxD, Y_init is NxD
+    #
+    # returns Y of same shape as X, with
+    #
+    # Y[:, t] = A[:, 0] * Y_init   + X[:, 0] if t == 0
+    #         = A[:, t] * Y[:, t-1] + X[:, t] otherwise
+
+    @staticmethod
+    def forward(ctx, A, X, Y_init):
+        ctx.A = A.unsqueeze(-1).clone()
+        ctx.Y_init = Y_init[:, None].clone()
+        ctx.A_star = ctx.A.clone()
+        ctx.X_star = X.clone()
+        PScan.expand_(ctx.A_star, ctx.X_star)
+        return ctx.A_star * ctx.Y_init + ctx.X_star
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        U = grad_output * ctx.A_star
+        A = ctx.A.clone()
+        R = grad_output.clone()
+        PScan.acc_rev_(A, R)
+        Q = ctx.Y_init.expand_as(ctx.X_star).clone()
+        Q[:, 1:].mul_(ctx.A_star[:, :-1]).add_(ctx.X_star[:, :-1])
+        return (Q * R).sum(-1), R, U.sum(dim=1)
+
+
+pscan = PScan.apply
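+
+# A minimal usage sketch, with the shapes documented in forward():
+#
+#   A = torch.rand(2, 5)        # N x T
+#   X = torch.randn(2, 5, 3)    # N x T x D
+#   Y_init = torch.randn(2, 3)  # N x D
+#   Y = pscan(A, X, Y_init)     # N x T x D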
+
+######################################################################
+
+if __name__ == "__main__":
+    import time, sys
+
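+    # Quick shape-compatibility check with extra trailing dimensions;
+    # the exit(0) skips the gradient checks below.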
+    A = torch.rand(17, 12, 3)
+    X = torch.rand(17, 12, 3, 11)
+    Y_init = torch.rand(17, 3, 11)
+    Y = pscan(A, X, Y_init)
+    exit(0)
+
+    N, T, D = 2, 1047, 3
+
+    A = torch.rand(N, T, dtype=torch.float64).requires_grad_()
+    X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
+    Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_()
+
+    # Iterative implementation
+
+    y = Y_init
+    s = 0
+
+    for k in range(A.size(1)):
+        y = A[:, k, None] * y + X[:, k]
+        s = s + y
+
+    s = s.sum()
+
+    gA_ref, gX_ref, gY_init_ref = torch.autograd.grad(
+        s, (A, X, Y_init), retain_graph=True
+    )
+
+    # parallel scan
+
+    start_time = time.perf_counter()
+    for _ in range(1000):
+        Y = pscan(A, X, Y_init)
+    duration = time.perf_counter() - start_time
+    print(f"duration {duration}")
+
+    s = Y.sum()
+
+    gA, gX, gY_init = torch.autograd.grad(s, (A, X, Y_init), retain_graph=True)
+
+    # print(gA)
+    # print(gX)
+    # print(gY_init)
+
+    print((gA - gA_ref).norm())
+    print((gX - gX_ref).norm())
+    print((gY_init - gY_init_ref).norm())
+
+    Y1 = pscan(A[:, : T // 2], X[:, : T // 2], Y_init)
+    Y2 = pscan(A[:, T // 2 :], X[:, T // 2 :], Y1[:, -1])
+
+    print((Y - torch.cat([Y1, Y2], dim=1)).norm())
diff --git a/qmlp.py b/qmlp.py
new file mode 100755 (executable)
index 0000000..abebfc1
--- /dev/null
+++ b/qmlp.py
@@ -0,0 +1,378 @@
+#!/usr/bin/env python
+
+# @XREMOTE_HOST: elk.fleuret.org
+# @XREMOTE_EXEC: python
+# @XREMOTE_PRE: source ${HOME}/misc/venv/pytorch/bin/activate
+# @XREMOTE_PRE: killall -u ${USER} -q -9 python || true
+# @XREMOTE_PRE: ln -sf ${HOME}/data/pytorch ./data
+# @XREMOTE_SEND: *.py *.sh
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math, sys
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+nb_quantization_levels = 101
+
+
+def quantize(x, xmin, xmax):
+    return (
+        ((x - xmin) / (xmax - xmin) * nb_quantization_levels)
+        .long()
+        .clamp(min=0, max=nb_quantization_levels - 1)
+    )
+
+
+def dequantize(q, xmin, xmax):
+    return q / nb_quantization_levels * (xmax - xmin) + xmin
+
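+# For instance, with nb_quantization_levels = 101, quantize(x, -1, 1)
+# maps [-1, 1] to the integer bins {0, ..., 100}, and dequantize(q, -1, 1)
+# maps bin q back to q / 101 * 2 - 1, the low edge of the bin.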
+
+######################################################################
+
+
+def generate_sets_and_params(
+    batch_nb_mlps,
+    nb_samples,
+    batch_size,
+    nb_epochs,
+    device=torch.device("cpu"),
+    print_log=False,
+    save_as_examples=False,
+):
+    data_input = torch.zeros(batch_nb_mlps, 2 * nb_samples, 2, device=device)
+    data_targets = torch.zeros(
+        batch_nb_mlps, 2 * nb_samples, dtype=torch.int64, device=device
+    )
+
+    nb_rec = 8
+    nb_values = 2  # more increases the min-max gap
+
+    rec_support = torch.empty(batch_nb_mlps, nb_rec, 4, device=device)
+
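+    # rejection-sample the rectangle supports until every MLP's data
+    # set is roughly class-balanced (positive rate within [0.4, 0.6])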
+    while (data_targets.float().mean(-1) - 0.5).abs().max() > 0.1:
+        i = (data_targets.float().mean(-1) - 0.5).abs() > 0.1
+        nb = i.sum()
+        support = torch.rand(nb, nb_rec, 2, nb_values, device=device) * 2 - 1
+        support = support.sort(-1).values
+        support = support[:, :, :, torch.tensor([0, nb_values - 1])].view(nb, nb_rec, 4)
+
+        x = torch.rand(nb, 2 * nb_samples, 2, device=device) * 2 - 1
+        y = (
+            (
+                (x[:, None, :, 0] >= support[:, :, None, 0]).long()
+                * (x[:, None, :, 0] <= support[:, :, None, 1]).long()
+                * (x[:, None, :, 1] >= support[:, :, None, 2]).long()
+                * (x[:, None, :, 1] <= support[:, :, None, 3]).long()
+            )
+            .max(dim=1)
+            .values
+        )
+
+        data_input[i], data_targets[i], rec_support[i] = x, y, support
+
+    train_input, train_targets = (
+        data_input[:, :nb_samples],
+        data_targets[:, :nb_samples],
+    )
+    test_input, test_targets = data_input[:, nb_samples:], data_targets[:, nb_samples:]
+
+    q_train_input = quantize(train_input, -1, 1)
+    train_input = dequantize(q_train_input, -1, 1)
+
+    q_test_input = quantize(test_input, -1, 1)
+    test_input = dequantize(q_test_input, -1, 1)
+
+    if save_as_examples:
+        a = (
+            2
+            * torch.arange(nb_quantization_levels).float()
+            / (nb_quantization_levels - 1)
+            - 1
+        )
+        xf = torch.cat(
+            [
+                a[:, None, None].expand(
+                    nb_quantization_levels, nb_quantization_levels, 1
+                ),
+                a[None, :, None].expand(
+                    nb_quantization_levels, nb_quantization_levels, 1
+                ),
+            ],
+            2,
+        )
+        xf = xf.reshape(1, -1, 2).expand(min(q_train_input.size(0), 10), -1, -1)
+        print(f"{xf.size()=} {x.size()=}")
+        yf = (
+            (
+                (xf[:, None, :, 0] >= rec_support[: xf.size(0), :, None, 0]).long()
+                * (xf[:, None, :, 0] <= rec_support[: xf.size(0), :, None, 1]).long()
+                * (xf[:, None, :, 1] >= rec_support[: xf.size(0), :, None, 2]).long()
+                * (xf[:, None, :, 1] <= rec_support[: xf.size(0), :, None, 3]).long()
+            )
+            .max(dim=1)
+            .values
+        )
+
+        full_input, full_targets = xf, yf
+
+        q_full_input = quantize(full_input, -1, 1)
+        full_input = dequantize(q_full_input, -1, 1)
+
+        for k in range(q_full_input[:10].size(0)):
+            with open(f"example_full_{k:04d}.dat", "w") as f:
+                for u, c in zip(full_input[k], full_targets[k]):
+                    f.write(f"{c} {u[0].item()} {u[1].item()}\n")
+
+        for k in range(q_train_input[:10].size(0)):
+            with open(f"example_train_{k:04d}.dat", "w") as f:
+                for u, c in zip(train_input[k], train_targets[k]):
+                    f.write(f"{c} {u[0].item()} {u[1].item()}\n")
+
+    hidden_dim = 32
+    w1 = torch.randn(batch_nb_mlps, hidden_dim, 2, device=device) / math.sqrt(2)
+    b1 = torch.zeros(batch_nb_mlps, hidden_dim, device=device)
+    w2 = torch.randn(batch_nb_mlps, 2, hidden_dim, device=device) / math.sqrt(
+        hidden_dim
+    )
+    b2 = torch.zeros(batch_nb_mlps, 2, device=device)
+
+    w1.requires_grad_()
+    b1.requires_grad_()
+    w2.requires_grad_()
+    b2.requires_grad_()
+    optimizer = torch.optim.Adam([w1, b1, w2, b2], lr=1e-2)
+
+    criterion = nn.CrossEntropyLoss()
+    criterion.to(device)
+
+    for k in range(nb_epochs):
+        acc_train_loss = 0.0
+        nb_train_errors = 0
+
+        for input, targets in zip(
+            train_input.split(batch_size, dim=1), train_targets.split(batch_size, dim=1)
+        ):
+            h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
+            h = F.relu(h)
+            output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
+            loss = F.cross_entropy(
+                output.reshape(-1, output.size(-1)), targets.reshape(-1)
+            )
+            acc_train_loss += loss.item() * input.size(0)
+
+            wta = output.argmax(-1)
+            nb_train_errors += (wta != targets).long().sum(-1)
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
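+        # progressively quantize the parameters: the fraction of
+        # weights snapped to their quantized value grows linearly from
+        # 0 to 1 over the epochs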
+        with torch.no_grad():
+            for p in [w1, b1, w2, b2]:
+                m = (
+                    torch.rand(p.size(), device=p.device) <= k / (nb_epochs - 1)
+                ).long()
+                pq = quantize(p, -2, 2)
+                p[...] = (1 - m) * p + m * dequantize(pq, -2, 2)
+
+        train_error = nb_train_errors / train_input.size(1)
+        acc_train_loss = acc_train_loss / train_input.size(1)
+
+        # print(f"{k=} {acc_train_loss=} {train_error=}")
+
+    acc_test_loss = 0
+    nb_test_errors = 0
+
+    for input, targets in zip(
+        test_input.split(batch_size, dim=1), test_targets.split(batch_size, dim=1)
+    ):
+        h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
+        h = F.relu(h)
+        output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
+        loss = F.cross_entropy(output.reshape(-1, output.size(-1)), targets.reshape(-1))
+        acc_test_loss += loss.item() * input.size(0)
+
+        wta = output.argmax(-1)
+        nb_test_errors += (wta != targets).long().sum(-1)
+
+    test_error = nb_test_errors / test_input.size(1)
+    q_params = torch.cat(
+        [quantize(p.view(batch_nb_mlps, -1), -2, 2) for p in [w1, b1, w2, b2]], dim=1
+    )
+    q_train_set = torch.cat([q_train_input, train_targets[:, :, None]], -1).reshape(
+        batch_nb_mlps, -1
+    )
+    q_test_set = torch.cat([q_test_input, test_targets[:, :, None]], -1).reshape(
+        batch_nb_mlps, -1
+    )
+
+    return q_train_set, q_test_set, q_params, test_error
+
+
+######################################################################
+
+
+def evaluate_q_params(
+    q_params,
+    q_set,
+    batch_size=25,
+    device=torch.device("cpu"),
+    nb_mlps_per_batch=1024,
+    save_as_examples=False,
+):
+    errors = []
+    nb_mlps = q_params.size(0)
+
+    for n in range(0, nb_mlps, nb_mlps_per_batch):
+        batch_nb_mlps = min(nb_mlps_per_batch, nb_mlps - n)
+        batch_q_params = q_params[n : n + batch_nb_mlps]
+        batch_q_set = q_set[n : n + batch_nb_mlps]
+        hidden_dim = 32
+        w1 = torch.empty(batch_nb_mlps, hidden_dim, 2, device=device)
+        b1 = torch.empty(batch_nb_mlps, hidden_dim, device=device)
+        w2 = torch.empty(batch_nb_mlps, 2, hidden_dim, device=device)
+        b2 = torch.empty(batch_nb_mlps, 2, device=device)
+
+        with torch.no_grad():
+            k = 0
+            for p in [w1, b1, w2, b2]:
+                print(f"{p.size()=}")
+                x = dequantize(
+                    batch_q_params[:, k : k + p.numel() // batch_nb_mlps], -2, 2
+                ).view(p.size())
+                p.copy_(x)
+                k += p.numel() // batch_nb_mlps
+
+        batch_q_set = batch_q_set.view(batch_nb_mlps, -1, 3)
+        data_input = dequantize(batch_q_set[:, :, :2], -1, 1).to(device)
+        data_targets = batch_q_set[:, :, 2].to(device)
+
+        print(f"{data_input.size()=} {data_targets.size()=}")
+
+        criterion = nn.CrossEntropyLoss()
+        criterion.to(device)
+
+        acc_loss = 0.0
+        nb_errors = 0
+
+        for input, targets in zip(
+            data_input.split(batch_size, dim=1), data_targets.split(batch_size, dim=1)
+        ):
+            h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
+            h = F.relu(h)
+            output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
+            loss = F.cross_entropy(
+                output.reshape(-1, output.size(-1)), targets.reshape(-1)
+            )
+            acc_loss += loss.item() * input.size(0)
+            wta = output.argmax(-1)
+            nb_errors += (wta != targets).long().sum(-1)
+
+        errors.append(nb_errors / data_input.size(1))
+        acc_loss = acc_loss / data_input.size(1)
+
+    return torch.cat(errors)
+
+
+######################################################################
+
+
+def generate_sequence_and_test_set(
+    nb_mlps,
+    nb_samples,
+    batch_size,
+    nb_epochs,
+    device,
+    nb_mlps_per_batch=1024,
+):
+    seqs, q_test_sets, test_errors = [], [], []
+
+    for n in range(0, nb_mlps, nb_mlps_per_batch):
+        q_train_set, q_test_set, q_params, test_error = generate_sets_and_params(
+            batch_nb_mlps=min(nb_mlps_per_batch, nb_mlps - n),
+            nb_samples=nb_samples,
+            batch_size=batch_size,
+            nb_epochs=nb_epochs,
+            device=device,
+        )
+
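+        # each row is the quantized train set (3 tokens per sample: two
+        # input coordinates and the class), a separator token, then the
+        # quantized MLP parameters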
+        seqs.append(
+            torch.cat(
+                [
+                    q_train_set,
+                    q_train_set.new_full(
+                        (
+                            q_train_set.size(0),
+                            1,
+                        ),
+                        nb_quantization_levels,
+                    ),
+                    q_params,
+                ],
+                dim=-1,
+            )
+        )
+
+        q_test_sets.append(q_test_set)
+        test_errors.append(test_error)
+
+    seq = torch.cat(seqs)
+    q_test_set = torch.cat(q_test_sets)
+    test_error = torch.cat(test_errors)
+
+    return seq, q_test_set, test_error
+
+
+######################################################################
+
+if __name__ == "__main__":
+    import time
+
+    batch_nb_mlps, nb_samples = 128, 250
+
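+    # Small demo run that also dumps example_*.dat files; the exit(0)
+    # below skips the timing / evaluation test.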
+    generate_sets_and_params(
+        batch_nb_mlps=10,
+        nb_samples=nb_samples,
+        batch_size=25,
+        nb_epochs=100,
+        device=torch.device("cpu"),
+        print_log=False,
+        save_as_examples=True,
+    )
+
+    exit(0)
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    start_time = time.perf_counter()
+
+    data = []
+
+    seq, q_test_set, test_error = generate_sequence_and_test_set(
+        nb_mlps=batch_nb_mlps,
+        nb_samples=nb_samples,
+        device=device,
+        batch_size=25,
+        nb_epochs=250,
+        nb_mlps_per_batch=17,
+    )
+
+    end_time = time.perf_counter()
+    print(f"{seq.size(0) / (end_time - start_time):.02f} samples per second")
+
+    q_train_set = seq[:, : nb_samples * 3]
+    q_params = seq[:, nb_samples * 3 + 1 :]
+    print(f"SANITY #2 {q_train_set.size()=} {q_params.size()=} {seq.size()=}")
+    error_train = evaluate_q_params(q_params, q_train_set, nb_mlps_per_batch=17)
+    print(f"train {error_train*100}%")
+    error_test = evaluate_q_params(q_params, q_test_set, nb_mlps_per_batch=17)
+    print(f"test {error_test*100}%")
diff --git a/rpl.py b/rpl.py
new file mode 100755 (executable)
index 0000000..b848afa
--- /dev/null
+++ b/rpl.py
@@ -0,0 +1,177 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+
+def rpl_exec(program, stack):
+    stack = stack.copy()
+    for op in program:
+        if op == "add":
+            if len(stack) > 1:
+                a, b = stack.pop(), stack.pop()
+                stack.append(a + b)
+        elif op == "min":
+            if len(stack) > 1:
+                a, b = stack.pop(), stack.pop()
+                stack.append(min(a, b))
+        elif op == "max":
+            if len(stack) > 1:
+                a, b = stack.pop(), stack.pop()
+                stack.append(max(a, b))
+        elif op == "swp":
+            if len(stack) > 1:
+                a, b = stack.pop(), stack.pop()
+                stack.append(a)
+                stack.append(b)
+        elif op == "rep":
+            if len(stack) > 1:
+                a, b = stack.pop(), stack.pop()
+                stack += [b] * a
+        elif op == "dup":
+            if len(stack) > 0:
+                a = stack.pop()
+                stack.append(a)
+                stack.append(a)
+        elif op == "del":
+            if len(stack) > 0:
+                a = stack.pop()
+        else:
+            raise ValueError(f"Unknown instruction {op}")
+
+    return stack
+
+
+rpl_ops = ["add", "min", "max", "swp", "rep", "dup", "del"]
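+
+# A few hand-checked examples:
+#
+#   rpl_exec(["dup", "add"], [3])  ->  [6]
+#   rpl_exec(["swp"], [1, 2])      ->  [2, 1]
+#   rpl_exec(["rep"], [2, 3])      ->  [2, 2, 2]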
+
+######################################################################
+
+
+def generate(
+    nb_starting_values=3, nb_result_values_max=None, max_input=9, prog_len=6, nb_runs=5
+):
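+    # prog_len is kept at its full value roughly half of the time, and
+    # is uniform between 1 and prog_len otherwise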
+    prog_len = (1 + torch.randint(2 * prog_len, (1,))).clamp(max=prog_len).item()
+
+    while True:
+        no_empty_stack = True
+        prog = [rpl_ops[k] for k in torch.randint(len(rpl_ops), (prog_len,))]
+
+        result = []
+        for _ in range(nb_runs):
+            stack = [
+                x.item() for x in torch.randint(max_input + 1, (nb_starting_values,))
+            ]
+            result_stack = rpl_exec(prog, stack)
+            if len(result_stack) == 0:
+                no_empty_stack = False
+            result = result + ["<in>"] + stack + ["<out>"] + result_stack
+
+        result = result + ["<prg>"] + prog
+        result = result + ["<end>"]
+
+        if no_empty_stack and (
+            nb_result_values_max is None or len(result_stack) <= nb_result_values_max
+        ):
+            break
+
+    return result
+
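+# A generated sequence reads
+#
+#   <in> s_1 ... <out> r_1 ... <in> s_2 ... <out> r_2 ... <prg> op_1 ... <end>
+#
+# with nb_runs input/output pairs; decompose() below parses it back.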
+
+def next_marker(seq, tokens, start=0):
+    pos = None
+    for t in tokens:
+        try:
+            i = seq.index(t, start)
+            if pos is None or i < pos:
+                pos = i
+        except ValueError:
+            pass
+    return pos
+
+
+def decompose(seq):
+    io = []
+    k = 0
+    while seq[k] == "<in>":
+        o = next_marker(seq, ["<out>"], start=k + 1)
+        if o is None:
+            raise ValueError("Missing output markers (should be correct in the prompt)")
+        e = next_marker(seq, ["<in>", "<prg>"], start=o)
+        if e is None:
+            raise ValueError(
+                "Missing input/output markers (should be correct in the prompt)"
+            )
+        try:
+            io.append(
+                ([int(x) for x in seq[k + 1 : o]], [int(x) for x in seq[o + 1 : e]])
+            )
+        except ValueError:
+            raise ValueError(
+                "Invalid input/output value (should be correct in the prompt)"
+            )
+
+        k = e
+
+    if seq[k] == "<prg>":
+        e = next_marker(seq, ["<end>"], start=k)
+        if e is None:
+            prog = []
+        else:
+            prog = seq[k + 1 : e]
+    else:
+        raise ValueError("Missing <prg> (it should be in the prompt)")
+
+    return prog, io
+
+
+def stack_distance(target_stack, result_stack):
+    return abs(len(result_stack) - len(target_stack)) + sum(
+        [0 if x == y else 1 for x, y in zip(result_stack, target_stack)]
+    )
+
+
+def compute_nb_errors(seq):
+    prog, io = decompose(seq)
+
+    nb_total, nb_errors = 0, 0
+
+    stacks = []
+
+    if len(set(prog) - set(rpl_ops)) > 0:
+        # Program is not valid, we count 100% error
+        for start_stack, target_stack in io:
+            stacks.append((start_stack, target_stack, ["N/A"], False))
+            nb_total += len(target_stack)
+            nb_errors += len(target_stack)
+
+    else:
+        # Program is valid
+        for start_stack, target_stack in io:
+            result_stack = rpl_exec(prog, start_stack)
+            nb_total += len(target_stack)
+            e = stack_distance(target_stack, result_stack)
+            nb_errors += e
+            stacks.append((start_stack, target_stack, result_stack, e == 0))
+
+    return nb_total, nb_errors, prog, stacks
+
+
+######################################################################
+
+if __name__ == "__main__":
+    seq = generate()
+    print(seq)
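+    # corrupt one of the input values to check that errors are counted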
+    seq[3] = 7
+    print(seq)
+    print(compute_nb_errors(seq))
diff --git a/snake.py b/snake.py
new file mode 100755 (executable)
index 0000000..8a16f9f
--- /dev/null
+++ b/snake.py
@@ -0,0 +1,132 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import torch, torchvision
+import torch.nn.functional as F
+
+
+def generate_sequences(
+    nb, height, width, nb_colors, length, prompt_length, device=torch.device("cpu")
+):
+    worlds = torch.randint(nb_colors, (nb, height, width), device=device)
+    world_prior_visits = torch.zeros(nb, height, width, device=device)
+
+    # nb x 2
+    snake_position = torch.cat(
+        (
+            torch.randint(height, (nb, 1), device=device),
+            torch.randint(width, (nb, 1), device=device),
+        ),
+        1,
+    )
+    snake_direction = torch.randint(4, (nb,), device=device)
+    sequences = torch.empty(nb, 2 * length, device=device, dtype=torch.int64)
+    sequences_prior_visits = torch.zeros(
+        nb, 2 * length, device=device, dtype=torch.int64
+    )
+    i = torch.arange(nb, device=device)  # [:,None]
+
+    for l in range(length):
+        # nb x 3
+        snake_next_direction = torch.cat(
+            (
+                (snake_direction[:, None] - 1) % 4,
+                snake_direction[:, None],
+                (snake_direction[:, None] + 1) % 4,
+            ),
+            1,
+        )
+
+        # nb x 3
+        vh = (snake_next_direction + 1) % 2 * (snake_next_direction - 1)
+        vw = snake_next_direction % 2 * (snake_next_direction - 2)
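+        # this maps direction d to a (dh, dw) displacement:
+        # 0 -> (-1, 0), 1 -> (0, -1), 2 -> (1, 0), 3 -> (0, 1)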
+
+        # nb x 3 x 2
+        snake_next_speed = torch.cat((vh[:, :, None], vw[:, :, None]), 2)
+        snake_next_position = snake_position[:, None, :] + snake_next_speed
+
+        # nb x 3
+        val = torch.logical_and(
+            torch.logical_and(
+                snake_next_position[:, :, 0] >= 0, snake_next_position[:, :, 0] < height
+            ),
+            torch.logical_and(
+                snake_next_position[:, :, 1] >= 0, snake_next_position[:, :, 1] < width
+            ),
+        ).float()
+        val = (
+            # The multiplicative factors bias toward moving forward
+            torch.rand_like(val)
+            * val
+            * torch.tensor([[1.0, 2.0, 1.0]], device=device)
+        )
+
+        # nb
+        j = val.argmax(1)
+        snake_direction = snake_next_direction[i, j]
+
+        sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4
+        sequences_prior_visits[:, 2 * l] = world_prior_visits[
+            i, snake_position[:, 0], snake_position[:, 1]
+        ]
+        if l < prompt_length:
+            world_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1
+        sequences[:, 2 * l + 1] = snake_direction
+
+        # nb x 2
+        snake_position = snake_next_position[i, j]
+
+    return sequences, sequences_prior_visits, worlds, world_prior_visits
+
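+# Each sequence interleaves, for every one of the length time steps,
+# the color of the visited cell (offset by the 4 direction codes) and
+# the direction then taken, hence 2 * length tokens per sample.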
+
+# generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20)
+# exit(0)
+
+
+def solver(input, ar_mask):
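+    # Replays each trajectory: even positions hold cell colors, odd
+    # ones directions. Wherever ar_mask is 1 the color is filled in
+    # from the memory of previously visited cells (-1 if unseen).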
+    for n in range(input.size(0)):
+        i, j, memory = 0, 0, {}
+        # print(input[n])
+        # print(ar_mask[n])
+        for l in range(input.size(1) // 2):
+            if ar_mask[n, 2 * l] == 1:
+                if memory.get((i, j)) is None:
+                    input[n, 2 * l] = -1
+                else:
+                    input[n, 2 * l] = memory[(i, j)]
+            else:
+                # print(f'@3 {memory=}')
+                if memory.get((i, j)) is None:
+                    memory[(i, j)] = input[n, 2 * l]
+                else:
+                    assert memory[(i, j)] == input[n, 2 * l], f"n={n} l={l}"
+            # print(f'@1 {i=} {j=}')
+            d = input[n, 2 * l + 1].item()
+            i += (d + 1) % 2 * (d - 1)
+            j += d % 2 * (d - 2)
+            # print(f'@2 {i=} {j=}')
+
+
+def seq2str(seq):
+    return "".join(["NESW123456789"[i] for i in seq])
+
+
+######################################################################
+
+if __name__ == "__main__":
+    train_input, train_prior_visits, _, _ = generate_sequences(
+        nb=20,
+        height=9,
+        width=12,
+        nb_colors=5,
+        length=50,
+        prompt_length=100,
+    )
+
+    print([seq2str(s) for s in train_input])
+
+######################################################################
diff --git a/stack.py b/stack.py
new file mode 100755 (executable)
index 0000000..543f04e
--- /dev/null
+++ b/stack.py
@@ -0,0 +1,107 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import torch, torchvision
+
+######################################################################
+
+# CODE_OP = 2 * stack_index + op, with op = 0 for push and 1 for pop
+# CODE_VAL = digit + 2 * nb_stacks
+
+
+def generate_sequences(
+    nb, nb_steps, nb_stacks, nb_digits, values=None, device=torch.device("cpu")
+):
+    stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64)
+    stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64)
+    k = torch.arange(nb)
+    result = torch.empty(nb, (1 + nb_digits) * nb_steps, dtype=torch.int64)
+    recorded_stack_counts = torch.zeros(
+        nb, (1 + nb_digits) * nb_steps, dtype=torch.int64
+    )
+
+    for t in range(nb_steps):
+        op = torch.randint(2, (nb,))
+        st = torch.randint(nb_stacks, (nb,))
+        op = op * (stack_counts[k, st] > 0)
+        if values is None:
+            val_push = torch.randint(10**nb_digits, (nb,))
+        else:
+            val_push = values[torch.randint(values.size(0), (nb,))]
+        val_pop = stack[
+            k,
+            st,
+            (stack_counts[k, st] - 1).clamp(min=0),
+        ]
+        stack[k, st, stack_counts[k, st]] = val_push
+        recorded_stack_counts[:, (1 + nb_digits) * t] = stack_counts[k, st]
+        stack_counts[k[op == 0], st[op == 0]] += 1
+        stack_counts[k[op == 1], st[op == 1]] -= 1
+        result[:, (1 + nb_digits) * t] = st * 2 + op
+        for d in range(nb_digits):
+            result[:, (1 + nb_digits) * t + 1 + d] = (
+                (op * val_pop + (1 - op) * val_push) // (10**d)
+            ) % 10 + 2 * nb_stacks
+
+    return result.to(device), recorded_stack_counts.to(device)
+
+
+def remove_popped_values(seq, nb_stacks, nb_digits):
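+    # m flags the POP op-codes (odd values below 2 * nb_stacks); the
+    # nb_digits value tokens following each of them are replaced with
+    # -1 so that they have to be predicted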
+    m = torch.logical_and(seq % 2 == 1, seq < 2 * nb_stacks).long()
+    for d in range(nb_digits):
+        k = d + 1
+        seq[:, k:] = -m[:, :-k] + (1 - m[:, :-k]) * seq[:, k:]
+
+
+def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None):
+    assert seq.size(0) % (1 + nb_digits) == 0
+    s = ""
+    for t in range(seq.size(0) // (1 + nb_digits)):
+        n_op = seq[(1 + nb_digits) * t]
+        if t > 0:
+            s += " "
+        if recorded_stack_counts is not None:
+            s += f"[{recorded_stack_counts[(1 + nb_digits)*t]}] "
+        s += f"POP" if n_op % 2 == 1 else f"PSH"
+        if nb_stacks > 1:
+            s += f"_{n_op//2}"
+        for d in range(nb_digits):
+            if seq[(1 + nb_digits) * t + 1 + d] == -1:
+                s += " ?"
+            else:
+                s += f" {seq[(1 + nb_digits) * t + 1 + d] - 2 * nb_stacks:1d}"
+    return s
+
+
+######################################################################
+
+if __name__ == "__main__":
+    nb, nb_steps, nb_stacks, nb_digits = 150000, 20, 2, 1
+    seq, recorded_stack_counts = generate_sequences(
+        nb=nb,
+        nb_steps=nb_steps,
+        nb_stacks=nb_stacks,
+        nb_digits=nb_digits,
+    )
+
+    for n in range(min(10, seq.size(0))):
+        print(
+            seq_to_str(
+                seq[n],
+                nb_stacks=nb_stacks,
+                nb_digits=nb_digits,
+                recorded_stack_counts=recorded_stack_counts[n],
+            )
+        )
+        # print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))
+
+    print("-- PREPARED FOR TEST -----------------")
+
+    remove_popped_values(seq, nb_stacks, nb_digits)
+
+    for n in range(min(10, seq.size(0))):
+        print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))
diff --git a/tasks.py b/tasks.py
new file mode 100755 (executable)
index 0000000..58638ed
--- /dev/null
+++ b/tasks.py
@@ -0,0 +1,1663 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math, os, tqdm
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+from mygpt import BracketedSequence
+
+# from graph import save_attention_image
+save_attention_image = None
+
+######################################################################
+
+
+def masked_inplace_autoregression(
+    model,
+    batch_size,
+    input,
+    ar_mask,
+    deterministic_synthesis,
+    forbidden_tokens=None,
+    progress_bar_desc="autoregression",
+    device=torch.device("cpu"),
+):
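+    # input and ar_mask have the same shape; the tokens of input for
+    # which ar_mask is 1 are regenerated in place, batch by batch, with
+    # the model switched to eval mode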
+    assert input.size() == ar_mask.size()
+
+    batches = zip(input.split(batch_size), ar_mask.split(batch_size))
+
+    if progress_bar_desc is not None:
+        batches = tqdm.tqdm(
+            batches,
+            dynamic_ncols=True,
+            desc=progress_bar_desc,
+            total=(input.size(0) + batch_size - 1) // batch_size,
+        )
+
+    with torch.autograd.no_grad():
+        t = model.training
+        model.eval()
+
+        for input, ar_mask in batches:
+            model.masked_inplace_autoregression(
+                input, ar_mask, forbidden_tokens, deterministic_synthesis
+            )
+
+        model.train(t)
+
+
+######################################################################
+
+
+class Task:
+    def batches(self, split="train"):
+        pass
+
+    def vocabulary_size(self):
+        pass
+
+    def produce_results(
+        self, n_epoch, model, result_dir, logger, deterministic_synthesis
+    ):
+        pass
+
+
+####################
+
+import problems
+
+
+class SandBox(Task):
+    def __init__(
+        self,
+        problem,
+        nb_train_samples,
+        nb_test_samples,
+        batch_size,
+        logger=None,
+        device=torch.device("cpu"),
+        max_nb_codes=1024,
+    ):
+        super().__init__()
+
+        self.batch_size = batch_size
+        self.device = device
+        self.problem = problem
+
+        self.train_input, self.train_ar_mask = self.problem.generate_sequences(
+            nb_train_samples
+        )
+        self.test_input, self.test_ar_mask = self.problem.generate_sequences(
+            nb_test_samples
+        )
+
+        self.train_input, self.train_ar_mask = self.train_input.to(
+            device
+        ), self.train_ar_mask.to(device)
+        self.test_input, self.test_ar_mask = self.test_input.to(
+            device
+        ), self.test_ar_mask.to(device)
+
+        self.nb_codes = max(self.train_input.max(), self.test_input.max()).item() + 1
+
+        # A bit of paranoia never hurts
+        assert self.nb_codes <= max_nb_codes
+        assert self.train_input.min() >= 0
+        assert self.test_input.min() >= 0
+        assert tuple(x.item() for x in self.train_ar_mask.unique()) in {
+            (0,),
+            (1,),
+            (0, 1),
+        }
+        assert tuple(x.item() for x in self.test_ar_mask.unique()) in {
+            (0,),
+            (1,),
+            (0, 1),
+        }
+
+        if logger is not None:
+            for s, a in zip(self.train_input[:100], self.train_ar_mask[:100]):
+                logger(f"train_sequences {self.problem.seq2str(s)}")
+                a = "".join(["01"[x.item()] for x in a])
+                logger(f"                {a}")
+
+    def batches(self, split="train", nb_to_use=-1, desc=None):
+        assert split in {"train", "test"}
+        input = self.train_input if split == "train" else self.test_input
+        if nb_to_use > 0:
+            input = input[:nb_to_use]
+        if desc is None:
+            desc = f"epoch-{split}"
+        for batch in tqdm.tqdm(
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
+        ):
+            yield batch
+
+    def vocabulary_size(self):
+        return self.nb_codes
+
+    def produce_results(
+        self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
+    ):
+        def compute_accuracy(input, ar_mask, logger=None):
+            input, ar_mask = input[:nmax], ar_mask[:nmax]
+            result = input.clone() * (1 - ar_mask)
+
+            masked_inplace_autoregression(
+                model,
+                self.batch_size,
+                result,
+                ar_mask,
+                deterministic_synthesis,
+                progress_bar_desc=None,
+                device=self.device,
+            )
+
+            log_ground_truth = ar_mask.min() == 0
+
+            if logger is not None:
+                for sp, st in zip(result[:10], input[:10]):
+                    logger(
+                        f"test_sequences {n_epoch} prediction   {self.problem.seq2str(sp)}"
+                    )
+                    if log_ground_truth:
+                        logger(
+                            f"               {n_epoch} ground truth {self.problem.seq2str(st)}"
+                        )
+
+            nb_total, nb_correct = self.problem.compute_nb_correct(
+                input, ar_mask, result
+            )
+
+            # nb_total = ar_mask.sum().item()
+            # nb_correct = ((result == input).long() * ar_mask).sum().item()
+
+            return nb_total, nb_correct
+
+        train_nb_total, train_nb_correct = compute_accuracy(
+            self.train_input, self.train_ar_mask
+        )
+
+        logger(
+            f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
+        )
+
+        test_nb_total, test_nb_correct = compute_accuracy(
+            self.test_input, self.test_ar_mask, logger
+        )
+
+        logger(
+            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+        )
+
+        logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
+
+        if save_attention_image is not None:
+            for k in range(10):
+                ns = torch.randint(self.test_input.size(0), (1,)).item()
+                input = self.test_input[ns : ns + 1].clone()
+
+                with torch.autograd.no_grad():
+                    t = model.training
+                    model.eval()
+                    # model.record_attention(True)
+                    model(BracketedSequence(input))
+                    model.train(t)
+                    # ram = model.retrieve_attention()
+                    # model.record_attention(False)
+
+                # tokens_output = [c for c in self.problem.seq2str(input[0])]
+                # tokens_input = ["n/a"] + tokens_output[:-1]
+                # for n_head in range(ram[0].size(1)):
+                # filename = os.path.join(
+                # result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
+                # )
+                # attention_matrices = [m[0, n_head] for m in ram]
+                # save_attention_image(
+                # filename,
+                # tokens_input,
+                # tokens_output,
+                # attention_matrices,
+                # k_top=10,
+                ##min_total_attention=0.9,
+                # token_gap=12,
+                # layer_gap=50,
+                # )
+                # logger(f"wrote {filename}")
+
+
+######################################################################
+
+import picoclvr
+
+
+class PicoCLVR(Task):
+    # Make a tensor from a list of strings
+    def tensorize(self, descr):
+        token_descr = [s.strip().split(" ") for s in descr]
+        l = max([len(s) for s in token_descr])
+        token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
+        id_descr = [[self.token2id[u] for u in s] for s in token_descr]
+        return torch.tensor(id_descr, device=self.device)
+
+    # Make a list of strings from a tensor
+    def detensorize(self, x):
+        return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
+
+    # Trim all the tensors in the tuple z to remove as many tokens as
+    # possible from the left and right of the first tensor. If z is a
+    # tuple, all its elements are trimmed according to the trimming
+    # computed for the first one.
+    def trim(self, z, token="<nul>"):
+        n = self.token2id[token]
+        if isinstance(z, tuple):
+            x = z[0]
+            i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
+            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+            return tuple([t[:, a:b] for t in z])
+        else:
+            i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
+            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+            return z[:, a:b]
+
+    ######################
+
+    def __init__(
+        self,
+        nb_train_samples,
+        nb_test_samples,
+        batch_size,
+        height,
+        width,
+        nb_colors=5,
+        logger=None,
+        device=torch.device("cpu"),
+        pruner_train=None,
+        pruner_eval=None,
+    ):
+        super().__init__()
+
+        def generate_descr(nb, cache_suffix, pruner):
+            return picoclvr.generate(
+                nb,
+                height=self.height,
+                width=self.width,
+                nb_colors=nb_colors,
+                pruner=pruner,
+            )
+
+        self.height = height
+        self.width = width
+        self.batch_size = batch_size
+        self.device = device
+        self.pruner_train = pruner_train
+        self.pruner_eval = pruner_eval
+
+        if logger is not None:
+            logger(
+                f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
+            )
+
+        self.train_descr = generate_descr(
+            nb_train_samples, "train", pruner=self.pruner_train
+        )
+        self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)
+
+        # Build the tokenizer
+        tokens = {"<nul>", "<img>"}
+        for d in [self.train_descr, self.test_descr]:
+            for s in d:
+                for t in s.strip().split(" "):
+                    tokens.add(t)
+        # make this set a sorted list to get the same tensors given
+        # the same descr
+        tokens = list(tokens)
+        tokens.sort()
+        self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
+        self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
+        self.t_img, self.t_nul = self.token2id["<img>"], self.token2id["<nul>"]
+
+        # Tokenize the train and test sets
+        self.train_input = self.tensorize(self.train_descr)
+        self.test_input = self.tensorize(self.test_descr)
+
+    def batches(self, split="train"):
+        assert split in {"train", "test"}
+        input = self.train_input if split == "train" else self.test_input
+        for batch in tqdm.tqdm(
+            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
+        ):
+            yield self.trim(batch)
+
+    def vocabulary_size(self):
+        return len(self.token2id)
+
+    def compute_missing_properties(
+        self, n_epoch, model, logger, deterministic_synthesis, pruner=None
+    ):
+        acc_nb_requested_properties = []
+        acc_nb_missing_properties = []
+        acc_nb_results = 0
+
+        for input in tqdm.tqdm(
+            self.test_input.split(self.batch_size),
+            dynamic_ncols=True,
+            desc=f"test-properties",
+        ):
+            result = input.clone()
+            ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1)
+            result = (1 - ar_mask) * result + ar_mask * self.t_nul
+            masked_inplace_autoregression(
+                model,
+                self.batch_size,
+                result,
+                ar_mask,
+                deterministic_synthesis,
+                progress_bar_desc=None,
+                device=self.device,
+            )
+
+            result_descr = self.detensorize(result)
+            np = picoclvr.nb_properties(
+                result_descr,
+                height=self.height,
+                width=self.width,
+                pruner=pruner,
+            )
+            nb_requested_properties, _, nb_missing_properties = zip(*np)
+            acc_nb_requested_properties += nb_requested_properties
+            acc_nb_missing_properties += nb_missing_properties
+            acc_nb_results += len(result_descr)
+
+        nb_requested_properties = sum(acc_nb_requested_properties)
+        nb_missing_properties = sum(acc_nb_missing_properties)
+
+        prefix = "" if pruner is None else "pruned_"
+        logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
+        logger(
+            f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
+        )
+        logger(
+            f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
+        )
+
+        logger(
+            f"main_test_accuracy {n_epoch} {1-nb_missing_properties/nb_requested_properties}"
+        )
+
+    ######################################################################
+
+    def produce_results(
+        self, n_epoch, model, result_dir, logger, deterministic_synthesis
+    ):
+        self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis)
+
+        if self.pruner_eval is not None:
+            self.compute_missing_properties(
+                n_epoch, model, logger, deterministic_synthesis, self.pruner_eval
+            )
+
+        nb_tokens_to_generate = self.height * self.width + 3
+        result_descr = []
+        nb_per_primer = 8
+        primer = []
+
+        for primer_descr in [
+            "red above green <sep> green top <sep> blue right of red",
+            "there is red <sep> there is yellow <sep> there is blue",
+            "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
+            "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
+        ]:
+            primer += [primer_descr + " <img>"] * nb_per_primer
+
+        result = self.tensorize(primer)
+        fill = result.new_full(
+            result.size()[:-1] + (self.height * self.width + 1,), self.t_nul
+        )
+        result = torch.cat((result, fill), 1)
+        ar_mask = (result == self.t_nul).long()
+        masked_inplace_autoregression(
+            model,
+            self.batch_size,
+            result,
+            ar_mask,
+            deterministic_synthesis,
+            device=self.device,
+        )
+        result_descr = self.detensorize(result)
+
+        np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
+
+        acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
+        acc_nb_results = len(result_descr)
+
+        nb_requested_properties = sum(acc_nb_requested_properties)
+        nb_missing_properties = sum(acc_nb_missing_properties)
+
+        prefix = "demo_"
+        logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
+        logger(
+            f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
+        )
+        logger(
+            f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
+        )
+
+        img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)
+
+        if img.dim() == 5:
+            if img.size(1) == 1:
+                img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
+            else:
+                img = torch.cat(
+                    [
+                        torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
+                        for x in img
+                    ],
+                    0,
+                )
+
+        image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png")
+        torchvision.utils.save_image(
+            img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
+        )
+        logger(f"wrote {image_name}")
+
+
+######################################################################
+
+
+class MNIST(Task):
+    def __init__(
+        self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
+    ):
+        super().__init__()
+
+        self.nb_train_samples = (nb_train_samples,)
+        self.nb_test_samples = (nb_test_samples,)
+        self.batch_size = batch_size
+        self.device = device
+        data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True)
+        self.train_input = data_set.data[:nb_train_samples].view(-1, 28 * 28).long()
+        data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
+        self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
+
+    def batches(self, split="train", nb_to_use=-1, desc=None):
+        assert split in {"train", "test"}
+        input = self.train_input if split == "train" else self.test_input
+        if nb_to_use > 0:
+            input = input[:nb_to_use]
+        if desc is None:
+            desc = f"epoch-{split}"
+        for batch in tqdm.tqdm(
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
+        ):
+            yield batch
+
+    def vocabulary_size(self):
+        return 256
+
+    def produce_results(
+        self, n_epoch, model, result_dir, logger, deterministic_synthesis
+    ):
+        results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
+        ar_mask = torch.full_like(results, 1)
+        masked_inplace_autoregression(
+            model,
+            self.batch_size,
+            results,
+            ar_mask,
+            deterministic_synthesis,
+            device=self.device,
+        )
+        image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
+        torchvision.utils.save_image(
+            1 - results.reshape(-1, 1, 28, 28) / 255.0,
+            image_name,
+            nrow=16,
+            pad_value=0.8,
+        )
+        logger(f"wrote {image_name}")
+
+
+######################################################################
+
+import maze
+
+
+class Maze(Task):
+    def map2seq(self, *m):
+        return torch.cat([x.flatten(1) for x in m], 1)
+
+    def seq2map(self, s):
+        s = s.reshape(s.size(0), -1, self.height, self.width)
+        return (s[:, k] for k in range(s.size(1)))
+
+    def __init__(
+        self,
+        nb_train_samples,
+        nb_test_samples,
+        batch_size,
+        height,
+        width,
+        nb_walls,
+        device=torch.device("cpu"),
+    ):
+        super().__init__()
+
+        self.batch_size = batch_size
+        self.height = height
+        self.width = width
+        self.device = device
+
+        train_mazes, train_paths, _ = maze.create_maze_data(
+            nb_train_samples,
+            height=height,
+            width=width,
+            nb_walls=nb_walls,
+            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc="data-train"),
+        )
+        self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
+
+        test_mazes, test_paths, _ = maze.create_maze_data(
+            nb_test_samples,
+            height=height,
+            width=width,
+            nb_walls=nb_walls,
+            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc="data-test"),
+        )
+        self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
+
+        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+
+    def batches(self, split="train", nb_to_use=-1, desc=None):
+        assert split in {"train", "test"}
+        input = self.train_input if split == "train" else self.test_input
+        if nb_to_use > 0:
+            input = input[:nb_to_use]
+        if desc is None:
+            desc = f"epoch-{split}"
+        for batch in tqdm.tqdm(
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
+        ):
+            yield batch
+
+    def vocabulary_size(self):
+        return self.nb_codes
+
+    def compute_error(
+        self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
+    ):
+        nb_total, nb_correct = 0, 0
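+        # count[i, j] accumulates the number of correctly solved mazes
+        # whose optimal path has length i and predicted path length j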
+        count = torch.zeros(
+            self.width * self.height,
+            self.width * self.height,
+            device=self.device,
+            dtype=torch.int64,
+        )
+
+        for input in self.batches(split, nb_to_use):
+            result = input.clone()
+            ar_mask = result.new_zeros(result.size())
+            ar_mask[:, self.height * self.width :] = 1
+            result *= 1 - ar_mask
+            masked_inplace_autoregression(
+                model,
+                self.batch_size,
+                result,
+                ar_mask,
+                deterministic_synthesis,
+                progress_bar_desc=None,
+                device=self.device,
+            )
+            mazes, paths = self.seq2map(result)
+            path_correctness = maze.path_correctness(mazes, paths)
+            nb_correct += path_correctness.long().sum()
+            nb_total += mazes.size(0)
+
+            optimal_path_lengths = (
+                (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
+            )
+            predicted_path_lengths = (
+                (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
+            )
+            optimal_path_lengths = optimal_path_lengths[path_correctness]
+            predicted_path_lengths = predicted_path_lengths[path_correctness]
+            count[optimal_path_lengths, predicted_path_lengths] += 1
+
+        if count.max() == 0:
+            count = None
+        else:
+            count = count[
+                : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
+            ]
+
+        return nb_total, nb_correct, count
+
+    def produce_results(
+        self, n_epoch, model, result_dir, logger, deterministic_synthesis
+    ):
+        train_nb_total, train_nb_correct, _ = self.compute_error(
+            model,
+            "train",
+            nb_to_use=1000,
+            deterministic_synthesis=deterministic_synthesis,
+        )
+        logger(
+            f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
+        )
+
+        test_nb_total, test_nb_correct, count = self.compute_error(
+            model,
+            "test",
+            nb_to_use=1000,
+            deterministic_synthesis=deterministic_synthesis,
+        )
+        logger(
+            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+        )
+
+        logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
+
+        if count is not None:
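+            # a correct path of the same length as the optimal one is
+            # itself optimal, hence the diagonal of count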
+            proportion_optimal = count.diagonal().sum().float() / count.sum()
+            logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
+            with open(
+                os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
+            ) as f:
+                for i in range(count.size(0)):
+                    for j in range(count.size(1)):
+                        eol = " " if j < count.size(1) - 1 else "\n"
+                        f.write(f"{count[i,j]}{eol}")
+
+        input = self.test_input[:48]
+        result = input.clone()
+        ar_mask = result.new_zeros(result.size())
+        ar_mask[:, self.height * self.width :] = 1
+        result *= 1 - ar_mask
+        masked_inplace_autoregression(
+            model,
+            self.batch_size,
+            result,
+            ar_mask,
+            deterministic_synthesis,
+            device=self.device,
+        )
+
+        mazes, paths = self.seq2map(input)
+        _, predicted_paths = self.seq2map(result)
+
+        filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png")
+        maze.save_image(
+            filename,
+            mazes=mazes,
+            target_paths=paths,
+            predicted_paths=predicted_paths,
+            path_correct=maze.path_correctness(mazes, predicted_paths),
+            path_optimal=maze.path_optimality(paths, predicted_paths),
+        )
+        logger(f"wrote {filename}")
+
+
+######################################################################
+
+
+import snake
+
+
+class Snake(Task):
+    def __init__(
+        self,
+        nb_train_samples,
+        nb_test_samples,
+        batch_size,
+        height,
+        width,
+        nb_colors,
+        length,
+        prompt_length,
+        device=torch.device("cpu"),
+    ):
+        super().__init__()
+
+        self.batch_size = batch_size
+        self.height = height
+        self.width = width
+        self.device = device
+        self.prompt_length = prompt_length
+
+        self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
+            nb_train_samples,
+            height,
+            width,
+            nb_colors,
+            length,
+            prompt_length,
+            self.device,
+        )
+        self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
+            nb_test_samples,
+            height,
+            width,
+            nb_colors,
+            length,
+            prompt_length,
+            self.device,
+        )
+
+        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+
+    def batches(self, split="train", nb_to_use=-1, desc=None):
+        assert split in {"train", "test"}
+        input = self.train_input if split == "train" else self.test_input
+        if nb_to_use > 0:
+            input = input[:nb_to_use]
+        if desc is None:
+            desc = f"epoch-{split}"
+        for batch in tqdm.tqdm(
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
+        ):
+            yield batch
+
+    def vocabulary_size(self):
+        return self.nb_codes
+
+    def produce_results(
+        self, n_epoch, model, result_dir, logger, deterministic_synthesis
+    ):
+        def compute_nb_correct(input, prior_visits):
+            result = input.clone()
+            i = torch.arange(result.size(1), device=result.device)[None, :]
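+            # regenerate every second token past the prompt (the
+            # sequences interleave two token streams), and keep the
+            # rest from the input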
+            ar_mask = (
+                torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
+                .long()
+                .expand_as(result)
+            )
+            result *= 1 - ar_mask
+
+            masked_inplace_autoregression(
+                model,
+                self.batch_size,
+                result,
+                ar_mask,
+                deterministic_synthesis,
+                device=self.device,
+            )
+
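+            # grade only the predictions at cells the snake has
+            # already visited, where the answer is determined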
+            nb_total = ((prior_visits > 0) * ar_mask).sum()
+
+            nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum()
+
+            return nb_total, nb_correct
+
+        test_nb_total, test_nb_correct = compute_nb_correct(
+            self.test_input[:1000], self.test_prior_visits[:1000]
+        )
+
+        logger(
+            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+        )
+
+        logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
+
+
+######################################################################
+
+
+import stack
+
+
+class Stack(Task):
+    def __init__(
+        self,
+        nb_train_samples,
+        nb_test_samples,
+        batch_size,
+        logger,
+        nb_steps,
+        nb_stacks,
+        nb_digits,
+        fraction_values_for_train=None,
+        device=torch.device("cpu"),
+    ):
+        super().__init__()
+
+        self.batch_size = batch_size
+        self.nb_steps = nb_steps
+        self.nb_stacks = nb_stacks
+        self.nb_digits = nb_digits
+        self.device = device
+
+        if fraction_values_for_train is None:
+            values_for_train = None
+            values_for_test = None
+        else:
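+            # split the 10**nb_digits possible values into disjoint
+            # train and test sets, so that test sequences involve
+            # values never seen during training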
+            values = torch.randperm(10**nb_digits)
+            nb_for_train = int(values.size(0) * fraction_values_for_train)
+            values_for_train = values[:nb_for_train]
+            values_for_test = values[nb_for_train:]
+
+        self.train_input, self.train_stack_counts = stack.generate_sequences(
+            nb_train_samples,
+            nb_steps,
+            nb_stacks,
+            nb_digits,
+            values_for_train,
+            self.device,
+        )
+
+        self.test_input, self.test_stack_counts = stack.generate_sequences(
+            nb_test_samples,
+            nb_steps,
+            nb_stacks,
+            nb_digits,
+            values_for_test,
+            self.device,
+        )
+
+        i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
+        counts = self.test_stack_counts.flatten()[i.flatten()]
+        counts = F.one_hot(counts).sum(0)
+        logger(f"test_pop_stack_counts {counts}")
+
+        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+
+    def batches(self, split="train", nb_to_use=-1, desc=None):
+        assert split in {"train", "test"}
+        input = self.train_input if split == "train" else self.test_input
+        if nb_to_use > 0:
+            input = input[:nb_to_use]
+        if desc is None:
+            desc = f"epoch-{split}"
+        for batch in tqdm.tqdm(
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
+        ):
+            yield batch
+
+    def vocabulary_size(self):
+        return self.nb_codes
+
+    def produce_results(
+        self, n_epoch, model, result_dir, logger, deterministic_synthesis
+    ):
+        def compute_nb_correct(input):
+            result = input.clone()
+            stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
+            ar_mask = (result != input).long()
+            masked_inplace_autoregression(
+                model,
+                self.batch_size,
+                result,
+                ar_mask,
+                deterministic_synthesis,
+                device=self.device,
+            )
+
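+            # each operation spans 1 + nb_digits tokens (the opcode and
+            # the digits of its value); a popped value is counted
+            # correct only if all its digits are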
+            errors = ((result != input).long() * ar_mask).reshape(
+                -1, 1 + self.nb_digits
+            )
+            ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)
+
+            nb_total = ar_mask.max(1).values.sum()
+            nb_correct = nb_total - errors.max(1).values.sum()
+
+            return nb_total, nb_correct
+
+        test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
+
+        logger(
+            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+        )
+
+        logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
+
+        ##############################################################
+        # Log a few generated sequences
+        input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
+        result = input.clone()
+        stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
+        ar_mask = (result != input).long()
+
+        # for n in range(result.size(0)):
+        # logger(
+        # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
+        # )
+
+        masked_inplace_autoregression(
+            model,
+            self.batch_size,
+            result,
+            ar_mask,
+            deterministic_synthesis,
+            device=self.device,
+        )
+
+        for n in range(result.size(0)):
+            logger(
+                f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
+            )
+        ##############################################################
+
+
+######################################################################
+
+import rpl
+
+
+class RPL(Task):
+    def tensorize(self, sequences):
+        len_max = max([len(x) for x in sequences])
+        return torch.tensor(
+            [
+                [self.token2id[str(c)] for c in s + ["<nul>"] * (len_max - len(s))]
+                for s in sequences
+            ]
+        )
+
+    def seq2str(self, seq):
+        return " ".join([self.id2token[i] for i in seq])
+
+    def __init__(
+        self,
+        nb_train_samples,
+        nb_test_samples,
+        batch_size,
+        nb_starting_values=3,
+        max_input=9,
+        prog_len=6,
+        nb_runs=5,
+        no_prog=False,
+        logger=None,
+        device=torch.device("cpu"),
+    ):
+        super().__init__()
+
+        self.batch_size = batch_size
+        self.device = device
+        self.no_prog = no_prog
+
+        train_sequences = [
+            rpl.generate(
+                nb_starting_values=nb_starting_values,
+                nb_result_values_max=4 * nb_starting_values,
+                max_input=max_input,
+                prog_len=prog_len,
+                nb_runs=nb_runs,
+            )
+            for _ in tqdm.tqdm(range(nb_train_samples), desc="train-data")
+        ]
+
+        test_sequences = [
+            rpl.generate(
+                nb_starting_values=nb_starting_values,
+                nb_result_values_max=4 * nb_starting_values,
+                max_input=max_input,
+                prog_len=prog_len,
+                nb_runs=nb_runs,
+            )
+            for _ in tqdm.tqdm(range(nb_test_samples), desc="test-data")
+        ]
+
+        symbols = list(
+            set(["<nul>"] + [x for l in train_sequences + test_sequences for x in l])
+        )
+        val_max = max([x if type(x) is int else 0 for x in symbols])
+        symbols = list(filter(lambda x: type(x) is str, symbols))
+        symbols.sort()
+        symbols += [str(n) for n in range(val_max + 1)]
+        self.token2id = dict([(c, n) for n, c in enumerate(symbols)])
+        self.id2token = dict([(n, c) for c, n in self.token2id.items()])
+
+        self.t_nul = self.token2id["<nul>"]
+        self.t_input = self.token2id["<in>"]
+        self.t_output = self.token2id["<out>"]
+        self.t_prog = self.token2id["<prg>"]
+        self.t_end = self.token2id["<end>"]
+
+        self.train_input = self.tensorize(train_sequences)
+        self.test_input = self.tensorize(test_sequences)
+
+        if no_prog:
+            # Excise the program from every train and test example
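+            # keep everything up to the last <prg> token, put <end>
+            # right after it, and pad the rest with <nul>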
+            k = torch.arange(self.train_input.size(1), device=self.train_input.device)[
+                None, :
+            ]
+            p = (
+                ((self.train_input == self.t_prog).long() * k)
+                .max(1, keepdim=True)
+                .values
+            )
+            self.train_input = (
+                self.train_input * (k <= p).long()
+                + self.t_end * (k == p + 1).long()
+                + self.t_nul * (k > p + 1).long()
+            )
+            k = torch.arange(self.test_input.size(1), device=self.test_input.device)[
+                None, :
+            ]
+            p = (
+                ((self.test_input == self.t_prog).long() * k)
+                .max(1, keepdim=True)
+                .values
+            )
+            self.test_input = (
+                self.test_input * (k <= p).long()
+                + self.t_end * (k == p + 1).long()
+                + self.t_nul * (k > p + 1).long()
+            )
+
+        if logger is not None:
+            logger(f"value_max {val_max}")
+            for x in self.train_input[:25]:
+                end = (x != self.t_nul).nonzero().max().item() + 1
+                seq = [self.id2token[i.item()] for i in x[:end]]
+                s = " ".join(seq)
+                logger(f"example_seq {s}")
+
+        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+
+    def batches(self, split="train", nb_to_use=-1, desc=None):
+        assert split in {"train", "test"}
+        input = self.train_input if split == "train" else self.test_input
+        if nb_to_use > 0:
+            input = input[:nb_to_use]
+        if desc is None:
+            desc = f"epoch-{split}"
+        for batch in tqdm.tqdm(
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
+        ):
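+            # trim the trailing <nul> padding shared by the whole
+            # batch, keeping a small margin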
+            last = (batch != self.t_nul).max(0).values.nonzero().max() + 3
+            batch = batch[:, :last].to(self.device)
+            yield batch
+
+    def vocabulary_size(self):
+        return self.nb_codes
+
+    def produce_results(
+        self, n_epoch, model, result_dir, logger, deterministic_synthesis
+    ):
+        # --------------------------------------------------------------------
+        def compute_nb_errors_prog(input, nb_to_log=0):
+            result = input.clone()
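+            # mask everything strictly after the first <prg> token:
+            # cumsum(s) - s is 1 from the position following it onwards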
+            s = (result == self.t_prog).long()
+            ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
+            result = (1 - ar_mask) * result + ar_mask * self.t_nul
+
+            masked_inplace_autoregression(
+                model,
+                self.batch_size,
+                result,
+                ar_mask,
+                deterministic_synthesis,
+                device=self.device,
+            )
+
+            sum_nb_total, sum_nb_errors = 0, 0
+            for one_input, one_result in zip(input, result):
+                seq = [self.id2token[i.item()] for i in one_result]
+                nb_total, nb_errors, prog, stacks = rpl.compute_nb_errors(seq)
+                sum_nb_total += 1
+                sum_nb_errors += 0 if nb_errors == 0 else 1
+                if nb_to_log > 0:
+                    gt_seq = [self.id2token[i.item()] for i in one_input]
+                    _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq)
+                    gt_prog = " ".join([str(x) for x in gt_prog])
+                    prog = " ".join([str(x) for x in prog])
+                    comment = "*" if nb_errors == 0 else "-"
+                    logger(f"{comment} PROG [{gt_prog}] PREDICTED [{prog}]")
+                    for start_stack, target_stack, result_stack, correct in stacks:
+                        comment = "*" if correct else "-"
+                        start_stack = " ".join([str(x) for x in start_stack])
+                        target_stack = " ".join([str(x) for x in target_stack])
+                        result_stack = " ".join([str(x) for x in result_stack])
+                        logger(
+                            f"  {comment} [{start_stack}] -> [{target_stack}] PREDICTED [{result_stack}]"
+                        )
+                    nb_to_log -= 1
+
+            return sum_nb_total, sum_nb_errors
+
+        # --------------------------------------------------------------------
+        def compute_nb_errors_output(input, nb_to_log=0):
+            result = input.clone()
+            k = torch.arange(result.size(1), device=result.device)[None, :]
+            last_output_idx = (
+                ((result == self.t_output) * k).max(dim=1, keepdim=True).values
+            )
+            first_prog_idx = (
+                ((result == self.t_prog) * k).max(dim=1, keepdim=True).values
+            )
+            ar_mask = (k > last_output_idx).long() * (k < first_prog_idx).long()
+            result = (1 - ar_mask) * result + ar_mask * self.t_nul
+
+            masked_inplace_autoregression(
+                model,
+                self.batch_size,
+                result,
+                ar_mask,
+                deterministic_synthesis,
+                device=self.device,
+            )
+
+            sum_nb_total, sum_nb_errors = 0, 0
+            for one_input, one_result, start, end in zip(
+                input, result, last_output_idx, first_prog_idx
+            ):
+                sum_nb_total += 1
+                correct = (one_input - one_result).abs().max() == 0
+                sum_nb_errors += 0 if correct else 1
+                if nb_to_log > 0:
+                    result_stack = [
+                        self.id2token[t.item()] for t in one_result[start : end + 1]
+                    ]
+                    target_stack = [
+                        self.id2token[t.item()] for t in one_input[start : end + 1]
+                    ]
+                    comment = "*" if correct else "-"
+                    result_stack = " ".join([str(x) for x in result_stack])
+                    target_stack = " ".join([str(x) for x in target_stack])
+                    logger(
+                        f"output_test {comment} [{target_stack}] PREDICTED [{result_stack}]"
+                    )
+                    nb_to_log -= 1
+
+            return sum_nb_total, sum_nb_errors
+
+        # --------------------------------------------------------------------
+
+        if not self.no_prog:
+            test_nb_total, test_nb_errors = compute_nb_errors_prog(
+                self.test_input[:1000].to(self.device), nb_to_log=10
+            )
+
+            logger(
+                f"accuracy_prog_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
+            )
+
+            logger(f"main_test_accuracy {n_epoch} {1-test_nb_errors/test_nb_total}")
+
+        test_nb_total, test_nb_errors = compute_nb_errors_output(
+            self.test_input[:1000].to(self.device), nb_to_log=10
+        )
+
+        logger(
+            f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
+        )
+
+        if save_attention_image is not None:
+            ns = torch.randint(self.test_input.size(0), (1,)).item()
+            input = self.test_input[ns : ns + 1].clone()
+            last = (input != self.t_nul).max(0).values.nonzero().max() + 3
+            input = input[:, :last].to(self.device)
+
+            with torch.no_grad():
+                t = model.training
+                model.eval()
+                model.record_attention(True)
+                model(BracketedSequence(input))
+                model.train(t)
+                ram = model.retrieve_attention()
+                model.record_attention(False)
+
+            tokens_output = [self.id2token[i.item()] for i in input[0]]
+            tokens_input = ["n/a"] + tokens_output[:-1]
+            for n_head in range(ram[0].size(1)):
+                filename = os.path.join(
+                    result_dir, f"rpl_attention_{n_epoch}_h{n_head}.pdf"
+                )
+                attention_matrices = [m[0, n_head] for m in ram]
+                save_attention_image(
+                    filename,
+                    tokens_input,
+                    tokens_output,
+                    attention_matrices,
+                    k_top=10,
+                    # min_total_attention=0.9,
+                    token_gap=12,
+                    layer_gap=50,
+                )
+                logger(f"wrote {filename}")
+
+
+######################################################################
+
+
+import expr
+
+
+class Expr(Task):
+    def tensorize(self, sequences):
+        len_max = max([len(x) for x in sequences])
+        return torch.tensor(
+            [
+                [self.char2id[c] for c in s + "#" * (len_max - len(s))]
+                for s in sequences
+            ]
+        ).to(self.device)
+
+    def __init__(
+        self,
+        nb_train_samples,
+        nb_test_samples,
+        nb_variables,
+        sequence_length,
+        operand_max,
+        result_max,
+        batch_size,
+        device=torch.device("cpu"),
+    ):
+        super().__init__()
+
+        self.batch_size = batch_size
+        self.device = device
+
+        train_sequences = expr.generate_sequences(
+            nb_train_samples,
+            nb_variables=nb_variables,
+            length=sequence_length,
+            operand_max=operand_max,
+            result_max=result_max,
+        )
+
+        test_sequences = expr.generate_sequences(
+            nb_test_samples,
+            nb_variables=nb_variables,
+            length=sequence_length,
+            operand_max=operand_max,
+            result_max=result_max,
+        )
+
+        symbols = list(set("#" + "".join(train_sequences + test_sequences)))
+        symbols.sort()
+
+        self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
+        self.id2char = dict([(n, c) for c, n in self.char2id.items()])
+
+        self.filler, self.space = self.char2id["#"], self.char2id[" "]
+
+        self.train_input = self.tensorize(train_sequences)
+        self.test_input = self.tensorize(test_sequences)
+
+        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+
+    def batches(self, split="train", nb_to_use=-1, desc=None):
+        assert split in {"train", "test"}
+        input = self.train_input if split == "train" else self.test_input
+        if nb_to_use > 0:
+            input = input[:nb_to_use]
+        if desc is None:
+            desc = f"epoch-{split}"
+        for batch in tqdm.tqdm(
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
+        ):
+            last = (batch != self.filler).max(0).values.nonzero().max() + 3
+            batch = batch[:, :last]
+            yield batch
+
+    def vocabulary_size(self):
+        return self.nb_codes
+
+    def seq2str(self, s):
+        return "".join([self.id2char[k.item()] for k in s])
+
+    def produce_results(
+        self,
+        n_epoch,
+        model,
+        result_dir,
+        logger,
+        deterministic_synthesis,
+        input_file=None,
+    ):
+        def compute_nb_correct(input):
+            result = input.clone()
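+            # mask everything strictly after the first space token,
+            # which is where the values to predict start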
+            s = (result == self.space).long()
+            ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
+            result = (1 - ar_mask) * result + ar_mask * self.filler
+            masked_inplace_autoregression(
+                model,
+                self.batch_size,
+                result,
+                ar_mask,
+                deterministic_synthesis,
+                device=self.device,
+            )
+
+            nb_total = input.size(0)
+            nb_correct = (input == result).long().min(1).values.sum()
+
+            #######################################################################
+            # Compute predicted vs. true variable values
+
+            nb_delta = torch.zeros(5, dtype=torch.int64)
+            nb_missed = 0
+
+            values_input = expr.extract_results([self.seq2str(s) for s in input])
+            values_result = expr.extract_results([self.seq2str(s) for s in result])
+
+            filename = os.path.join(result_dir, f"expr_result_{n_epoch:04d}.txt")
+
+            with open(filename, "w") as f:
+                for i, r in zip(values_input, values_result):
+                    for n, vi in i.items():
+                        vr = r.get(n)
+                        f.write(f"{vi} {-1 if vr is None else vr}\n")
+
+                        if vr is None or vr < 0:
+                            nb_missed += 1
+                        else:
+                            d = abs(vr - vi)
+                            if d >= nb_delta.size(0):
+                                nb_missed += 1
+                            else:
+                                nb_delta[d] += 1
+
+            ######################################################################
+
+            return nb_total, nb_correct, nb_delta, nb_missed
+
+        (
+            test_nb_total,
+            test_nb_correct,
+            test_nb_delta,
+            test_nb_missed,
+        ) = compute_nb_correct(self.test_input[:10000])
+
+        logger(
+            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+        )
+
+        logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
+
+        nb_total = test_nb_delta.sum() + test_nb_missed
+        for d in range(test_nb_delta.size(0)):
+            logger(
+                f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%"
+            )
+        logger(
+            f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%"
+        )
+
+        ##############################################################
+        # Log a few generated sequences
+        if input_file is None:
+            input = self.test_input[:10]
+        else:
+            with open(input_file, "r") as f:
+                sequences = [e.strip() for e in f.readlines()]
+                sequences = [s + " " + "#" * 50 for s in sequences]
+                input = self.tensorize(sequences)
+
+        result = input.clone()
+        s = (result == self.space).long()
+        ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
+        result = (1 - ar_mask) * result + ar_mask * self.filler
+
+        for n in range(result.size(0)):
+            logger(f"test_before {self.seq2str(result[n])}")
+
+        masked_inplace_autoregression(
+            model,
+            self.batch_size,
+            result,
+            ar_mask,
+            deterministic_synthesis,
+            device=self.device,
+        )
+
+        correct = (1 - ar_mask) * self.space + ar_mask * input
+        for n in range(result.size(0)):
+            comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
+            logger(f"test_after  {self.seq2str(result[n])} {comment}")
+            logger(f"truth       {self.seq2str(correct[n])}")
+        ##############################################################
+
+
+######################################################################
+
+import grid
+
+
+class Grid(Task):
+    # Make a tensor from a list of strings
+    def str2tensor(self, descr):
+        token_descr = [s.strip().split(" ") for s in descr]
+        l = max([len(s) for s in token_descr])
+        token_descr = [s + ["#"] * (l - len(s)) for s in token_descr]
+        id_descr = [[self.token2id[u] for u in s] for s in token_descr]
+        return torch.tensor(id_descr, device=self.device)
+
+    # Make a list of strings from a tensor
+    def tensor2str(self, x):
+        return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
+
+    # Trim z, removing as many all-filler columns as possible on the
+    # left and on the right. If z is a tuple of tensors, all of them
+    # are trimmed according to the trimming computed for the first one
+    def trim(self, z, token="#"):
+        n = self.token2id[token]
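+        # a column can be dropped when all its rows hold the filler
+        # token; the cumsum of the padded mask locates the first and
+        # last columns where some row is not filler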
+        if type(z) == tuple:
+            x = z[0]
+            i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
+            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+            return tuple([t[:, a:b] for t in z])
+        else:
+            i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
+            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+            return z[:, a:b]
+
+    ######################
+
+    def __init__(
+        self,
+        nb_train_samples,
+        nb_test_samples,
+        batch_size,
+        size,
+        logger=None,
+        device=torch.device("cpu"),
+    ):
+        super().__init__()
+
+        self.device = device
+        self.batch_size = batch_size
+        self.grid_factory = grid.GridFactory(size=size)
+
+        if logger is not None:
+            logger(
+                f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
+            )
+
+        self.train_descr = self.grid_factory.generate_samples(
+            nb_train_samples, lambda r: tqdm.tqdm(r)
+        )
+        self.test_descr = self.grid_factory.generate_samples(
+            nb_test_samples, lambda r: tqdm.tqdm(r)
+        )
+
+        # Build the tokenizer
+        tokens = set()
+        for d in [self.train_descr, self.test_descr]:
+            for s in d:
+                for t in s.strip().split(" "):
+                    tokens.add(t)
+        # sort the token set so that identical descriptions always
+        # produce identical tensors
+        tokens = list(tokens)
+        tokens.sort()
+        tokens = ["#"] + tokens
+        self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
+        self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
+        self.t_nul = self.token2id["#"]
+        self.t_true = self.token2id["true"]
+        self.t_false = self.token2id["false"]
+
+        # Tokenize the train and test sets
+        self.train_input = self.str2tensor(self.train_descr)
+        self.test_input = self.str2tensor(self.test_descr)
+
+    def batches(self, split="train"):
+        assert split in {"train", "test"}
+        input = self.train_input if split == "train" else self.test_input
+        for batch in tqdm.tqdm(
+            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
+        ):
+            yield self.trim(batch)
+
+    def vocabulary_size(self):
+        return len(self.token2id)
+
+    def produce_results(
+        self, n_epoch, model, result_dir, logger, deterministic_synthesis
+    ):
+        correct = self.test_input[:1000]
+        result = correct.clone()
+        ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long()
+        result *= 1 - ar_mask  # paranoia: the masked tokens are regenerated anyway
+
+        logger(f"----------------------------------------------------------")
+
+        for e in self.tensor2str(result[:10]):
+            logger(f"test_before {e}")
+
+        masked_inplace_autoregression(
+            model,
+            self.batch_size,
+            result,
+            ar_mask,
+            deterministic_synthesis,
+            device=self.device,
+        )
+
+        logger(f"----------------------------------------------------------")
+
+        for e in self.tensor2str(result[:10]):
+            logger(f"test_after  {e}")
+
+        logger(f"----------------------------------------------------------")
+
+        nb_total = ar_mask.sum().item()
+        nb_correct = ((correct == result).long() * ar_mask).sum().item()
+
+        logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
+        logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
+
+
+######################################################################
+
+import qmlp
+
+
+class QMLP(Task):
+    ######################
+
+    def __init__(
+        self,
+        nb_train_samples,
+        nb_test_samples,
+        batch_size,
+        result_dir,
+        logger=None,
+        device=torch.device("cpu"),
+    ):
+        super().__init__()
+
+        self.device = device
+        self.batch_size = batch_size
+        self.nb_samples_per_mlp = 256
+
+        if logger is not None:
+            logger(
+                f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
+            )
+
+        seq, q_test_set, test_error = qmlp.generate_sequence_and_test_set(
+            nb_mlps=nb_train_samples + nb_test_samples,
+            nb_samples=self.nb_samples_per_mlp,
+            device=self.device,
+            batch_size=64,
+            nb_epochs=250,
+            nb_mlps_per_batch=1024,
+        )
+
+        self.train_input = seq[:nb_train_samples]
+        self.train_q_test_set = q_test_set[:nb_train_samples]
+        self.train_ref_test_errors = test_error[:nb_train_samples]
+        self.test_input = seq[nb_train_samples:]
+        self.test_q_test_set = q_test_set[nb_train_samples:]
+        self.test_ref_test_errors = test_error[nb_train_samples:]
+
+        filename = os.path.join(result_dir, f"train_errors_ref.dat")
+        with open(filename, "w") as f:
+            for e in self.train_ref_test_errors:
+                f.write(f"{e}\n")
+
+        filename = os.path.join(result_dir, f"test_errors_ref.dat")
+        with open(filename, "w") as f:
+            for e in self.test_ref_test_errors:
+                f.write(f"{e}\n")
+
+        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+
+    def batches(self, split="train"):
+        assert split in {"train", "test"}
+        input = self.train_input if split == "train" else self.test_input
+        for batch in tqdm.tqdm(
+            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
+        ):
+            yield batch
+
+    def vocabulary_size(self):
+        return self.nb_codes
+
+    def produce_results(
+        self, n_epoch, model, result_dir, logger, deterministic_synthesis
+    ):
+        correct = self.test_input[:1000]
+        result = correct.clone()
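+        # a sequence is nb_samples_per_mlp samples of three tokens
+        # each, a separator, then the quantized MLP parameters; only
+        # the parameter part is regenerated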
+        ar_mask = (
+            torch.arange(result.size(1), device=result.device)
+            > self.nb_samples_per_mlp * 3 + 1
+        ).long()[None, :]
+        ar_mask = ar_mask.expand_as(result)
+        result *= 1 - ar_mask  # paranoia: the masked tokens are regenerated anyway
+
+        masked_inplace_autoregression(
+            model,
+            self.batch_size,
+            result,
+            ar_mask,
+            deterministic_synthesis,
+            device=self.device,
+        )
+
+        q_train_set = result[:, : self.nb_samples_per_mlp * 3]
+        q_params = result[:, self.nb_samples_per_mlp * 3 + 1 :]
+        error_test = qmlp.evaluate_q_params(q_params, self.test_q_test_set)
+
+        filename = os.path.join(result_dir, f"test_errors_{n_epoch:04d}.dat")
+        with open(filename, "w") as f:
+            for e in error_test:
+                f.write(f"{e}\n")
+
+
+######################################################################
diff --git a/world.py b/world.py
new file mode 100755 (executable)
index 0000000..aad0bfb
--- /dev/null
+++ b/world.py
@@ -0,0 +1,485 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math, sys, tqdm
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+import cairo
+
+######################################################################
+
+
+class Box:
+    nb_rgb_levels = 10
+
+    def __init__(self, x, y, w, h, r, g, b):
+        self.x = x
+        self.y = y
+        self.w = w
+        self.h = h
+        self.r = r
+        self.g = g
+        self.b = b
+
+    def collision(self, scene):
+        for c in scene:
+            if (
+                self is not c
+                and max(self.x, c.x) <= min(self.x + self.w, c.x + c.w)
+                and max(self.y, c.y) <= min(self.y + self.h, c.y + c.h)
+            ):
+                return True
+        return False
+
+
+######################################################################
+
+
+class Normalizer(nn.Module):
+    def __init__(self, mu, std):
+        super().__init__()
+        self.register_buffer("mu", mu)
+        self.register_buffer("log_var", 2 * torch.log(std))
+
+    def forward(self, x):
+        return (x - self.mu) / torch.exp(self.log_var / 2.0)
+
+
+class SignSTE(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        # torch.sign() maps to {-1, 0, +1}; we want a strict binary {-1, +1}
+        s = (x >= 0).float() * 2 - 1
+
+        if self.training:
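+            # straight-through estimator: forward the hard sign,
+            # back-propagate through tanh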
+            u = torch.tanh(x)
+            return s + u - u.detach()
+        else:
+            return s
+
+
+class DiscreteSampler2d(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
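+        # one-hot of the arg-max over the channel dimension, with a
+        # softmax straight-through gradient during training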
+        s = (x >= x.max(-3, keepdim=True).values).float()
+
+        if self.training:
+            u = x.softmax(dim=-3)
+            return s + u - u.detach()
+        else:
+            return s
+
+
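+# entropy regularizer: pushes each bit's activation probability
+# towards 0.5 (one bit of entropy), capped at h_threshold, so that
+# the binary code uses its capacity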
+def loss_H(binary_logits, h_threshold=1):
+    p = binary_logits.sigmoid().mean(0)
+    h = (-p.xlogy(p) - (1 - p).xlogy(1 - p)) / math.log(2)
+    h.clamp_(max=h_threshold)
+    return h_threshold - h.mean()
+
+
+def train_encoder(
+    train_input,
+    test_input,
+    depth,
+    nb_bits_per_token,
+    dim_hidden=48,
+    lambda_entropy=0.0,
+    lr_start=1e-3,
+    lr_end=1e-4,
+    nb_epochs=10,
+    batch_size=25,
+    logger=None,
+    device=torch.device("cpu"),
+):
+    mu, std = train_input.float().mean(), train_input.float().std()
+
+    def encoder_core(depth, dim):
+        l = [
+            [
+                nn.Conv2d(
+                    dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
+                ),
+                nn.ReLU(),
+                nn.Conv2d(dim * 2**k, dim * 2 ** (k + 1), kernel_size=2, stride=2),
+                nn.ReLU(),
+            ]
+            for k in range(depth)
+        ]
+
+        return nn.Sequential(*[x for m in l for x in m])
+
+    def decoder_core(depth, dim):
+        l = [
+            [
+                nn.ConvTranspose2d(
+                    dim * 2 ** (k + 1), dim * 2**k, kernel_size=2, stride=2
+                ),
+                nn.ReLU(),
+                nn.ConvTranspose2d(
+                    dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
+                ),
+                nn.ReLU(),
+            ]
+            for k in range(depth - 1, -1, -1)
+        ]
+
+        return nn.Sequential(*[x for m in l for x in m])
+
+    encoder = nn.Sequential(
+        Normalizer(mu, std),
+        nn.Conv2d(3, dim_hidden, kernel_size=1, stride=1),
+        nn.ReLU(),
+        # 64x64
+        encoder_core(depth=depth, dim=dim_hidden),
+        # 8x8
+        nn.Conv2d(dim_hidden * 2**depth, nb_bits_per_token, kernel_size=1, stride=1),
+    )
+
+    quantizer = SignSTE()
+
+    decoder = nn.Sequential(
+        nn.Conv2d(nb_bits_per_token, dim_hidden * 2**depth, kernel_size=1, stride=1),
+        # 8x8
+        decoder_core(depth=depth, dim=dim_hidden),
+        # 64x64
+        nn.ConvTranspose2d(dim_hidden, 3 * Box.nb_rgb_levels, kernel_size=1, stride=1),
+    )
+
+    model = nn.Sequential(encoder, decoder)
+
+    nb_parameters = sum(p.numel() for p in model.parameters())
+
+    logger(f"vqae nb_parameters {nb_parameters}")
+
+    model.to(device)
+
+    for k in range(nb_epochs):
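+        # learning rate interpolated geometrically from lr_start down
+        # to lr_end over the epochs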
+        lr = math.exp(
+            math.log(lr_start) + math.log(lr_end / lr_start) / (nb_epochs - 1) * k
+        )
+        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+
+        acc_train_loss = 0.0
+
+        for input in tqdm.tqdm(train_input.split(batch_size), desc="vqae-train"):
+            input = input.to(device)
+            z = encoder(input)
+            zq = quantizer(z)
+            output = decoder(zq)
+
+            output = output.reshape(
+                output.size(0), -1, 3, output.size(2), output.size(3)
+            )
+
+            train_loss = F.cross_entropy(output, input)
+
+            if lambda_entropy > 0:
+                train_loss = train_loss + lambda_entropy * loss_H(z, h_threshold=0.5)
+
+            acc_train_loss += train_loss.item() * input.size(0)
+
+            optimizer.zero_grad()
+            train_loss.backward()
+            optimizer.step()
+
+        acc_test_loss = 0.0
+
+        # no gradients are needed for the test pass
+        with torch.no_grad():
+            for input in tqdm.tqdm(test_input.split(batch_size), desc="vqae-test"):
+                input = input.to(device)
+                z = encoder(input)
+                zq = quantizer(z)
+                output = decoder(zq)
+
+                output = output.reshape(
+                    output.size(0), -1, 3, output.size(2), output.size(3)
+                )
+
+                test_loss = F.cross_entropy(output, input)
+
+                acc_test_loss += test_loss.item() * input.size(0)
+
+        train_loss = acc_train_loss / train_input.size(0)
+        test_loss = acc_test_loss / test_input.size(0)
+
+        logger(f"vqae train {k} lr {lr} train_loss {train_loss} test_loss {test_loss}")
+        sys.stdout.flush()
+
+    return encoder, quantizer, decoder
+
+
+######################################################################
+
+
+def scene2tensor(xh, yh, scene, size):
+    width, height = size, size
+    pixel_map = torch.ByteTensor(width, height, 4).fill_(255)
+    data = pixel_map.numpy()
+    surface = cairo.ImageSurface.create_for_data(
+        data, cairo.FORMAT_ARGB32, width, height
+    )
+
+    ctx = cairo.Context(surface)
+    ctx.set_fill_rule(cairo.FILL_RULE_EVEN_ODD)
+
+    for b in scene:
+        ctx.move_to(b.x * size, b.y * size)
+        ctx.rel_line_to(b.w * size, 0)
+        ctx.rel_line_to(0, b.h * size)
+        ctx.rel_line_to(-b.w * size, 0)
+        ctx.close_path()
+        ctx.set_source_rgba(
+            b.r / (Box.nb_rgb_levels - 1),
+            b.g / (Box.nb_rgb_levels - 1),
+            b.b / (Box.nb_rgb_levels - 1),
+            1.0,
+        )
+        ctx.fill()
+
+    hs = size * 0.1
+    ctx.set_source_rgba(0.0, 0.0, 0.0, 1.0)
+    ctx.move_to(xh * size - hs / 2, yh * size - hs / 2)
+    ctx.rel_line_to(hs, 0)
+    ctx.rel_line_to(0, hs)
+    ctx.rel_line_to(-hs, 0)
+    ctx.close_path()
+    ctx.fill()
+
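+    # cairo stores ARGB32 as B, G, R, A bytes in memory (little
+    # endian): keep the first three planes, flip them to R, G, B, put
+    # channels first, and quantize 0..255 down to nb_rgb_levels levels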
+    return (
+        pixel_map[None, :, :, :3]
+        .flip(-1)
+        .permute(0, 3, 1, 2)
+        .long()
+        .mul(Box.nb_rgb_levels)
+        .floor_divide(256)
+    )
+
+
+def random_scene(nb_insert_attempts=3):
+    scene = []
+    colors = [
+        ((Box.nb_rgb_levels - 1), 0, 0),
+        (0, (Box.nb_rgb_levels - 1), 0),
+        (0, 0, (Box.nb_rgb_levels - 1)),
+        ((Box.nb_rgb_levels - 1), (Box.nb_rgb_levels - 1), 0),
+        (
+            (Box.nb_rgb_levels * 2) // 3,
+            (Box.nb_rgb_levels * 2) // 3,
+            (Box.nb_rgb_levels * 2) // 3,
+        ),
+    ]
+
+    for k in range(nb_insert_attempts):
+        wh = torch.rand(2) * 0.2 + 0.2
+        xy = torch.rand(2) * (1 - wh)
+        c = colors[torch.randint(len(colors), (1,)).item()]
+        b = Box(
+            xy[0].item(), xy[1].item(), wh[0].item(), wh[1].item(), c[0], c[1], c[2]
+        )
+        if not b.collision(scene):
+            scene.append(b)
+
+    return scene
+
+
+def generate_episode(steps, size=64):
+    delta = 0.1
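+    # an effect is (grasp, dx, dy): move the hand by (dx, dy),
+    # dragging the box under it when grasp is True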
+    effects = [
+        (False, 0, 0),
+        (False, delta, 0),
+        (False, 0, delta),
+        (False, -delta, 0),
+        (False, 0, -delta),
+        (True, delta, 0),
+        (True, 0, delta),
+        (True, -delta, 0),
+        (True, 0, -delta),
+    ]
+
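+    # rejection sampling: regenerate the whole episode until the
+    # grasped boxes have moved often enough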
+    while True:
+        frames = []
+
+        scene = random_scene()
+        xh, yh = tuple(x.item() for x in torch.rand(2))
+
+        actions = torch.randint(len(effects), (len(steps),))
+        nb_changes = 0
+
+        for s, a in zip(steps, actions):
+            if s:
+                frames.append(scene2tensor(xh, yh, scene, size=size))
+
+            grasp, dx, dy = effects[a]
+
+            if grasp:
+                for b in scene:
+                    if b.x <= xh and b.x + b.w >= xh and b.y <= yh and b.y + b.h >= yh:
+                        x, y = b.x, b.y
+                        b.x += dx
+                        b.y += dy
+                        if (
+                            b.x < 0
+                            or b.y < 0
+                            or b.x + b.w > 1
+                            or b.y + b.h > 1
+                            or b.collision(scene)
+                        ):
+                            b.x, b.y = x, y
+                        else:
+                            xh += dx
+                            yh += dy
+                            nb_changes += 1
+            else:
+                x, y = xh, yh
+                xh += dx
+                yh += dy
+                if xh < 0 or xh > 1 or yh < 0 or yh > 1:
+                    xh, yh = x, y
+
+        if nb_changes > len(steps) // 3:
+            break
+
+    return frames, actions
+
+
+######################################################################
+
+
+def generate_episodes(nb, steps):
+    all_frames, all_actions = [], []
+    for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world-data"):
+        frames, actions = generate_episode(steps)
+        all_frames += frames
+        all_actions += [actions[None, :]]
+    return torch.cat(all_frames, 0).contiguous(), torch.cat(all_actions, 0)
+
+
+def create_data_and_processors(
+    nb_train_samples,
+    nb_test_samples,
+    mode,
+    nb_steps,
+    depth=3,
+    nb_bits_per_token=8,
+    nb_epochs=10,
+    device=torch.device("cpu"),
+    device_storage=torch.device("cpu"),
+    logger=None,
+):
+    assert mode in ["first_last"]
+
+    if mode == "first_last":
+        steps = [True] + [False] * (nb_steps + 1) + [True]
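+        # keep only the first and the last frame of each episode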
+
+    if logger is None:
+        logger = lambda s: print(s)
+
+    train_input, train_actions = generate_episodes(nb_train_samples, steps)
+    train_input, train_actions = train_input.to(device_storage), train_actions.to(
+        device_storage
+    )
+    test_input, test_actions = generate_episodes(nb_test_samples, steps)
+    test_input, test_actions = test_input.to(device_storage), test_actions.to(
+        device_storage
+    )
+
+    encoder, quantizer, decoder = train_encoder(
+        train_input,
+        test_input,
+        depth=depth,
+        nb_bits_per_token=nb_bits_per_token,
+        lambda_entropy=1.0,
+        nb_epochs=nb_epochs,
+        logger=logger,
+        device=device,
+    )
+    encoder.train(False)
+    quantizer.train(False)
+    decoder.train(False)
+
+    z = encoder(train_input[:1].to(device))
+    pow2 = (2 ** torch.arange(z.size(1), device=device))[None, None, :]
+    z_h, z_w = z.size(2), z.size(3)
+
+    logger(f"vqae input {train_input[0].size()} output {z[0].size()}")
+
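+    # frame2seq packs the nb_bits_per_token binary code of every
+    # spatial position into a single integer token; seq2frame unpacks
+    # the bits and samples an image from the decoder's logits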
+    def frame2seq(input, batch_size=25):
+        seq = []
+        p = pow2.to(device)
+        for x in input.split(batch_size):
+            x = x.to(device)
+            z = encoder(x)
+            ze_bool = (quantizer(z) >= 0).long()
+            output = (
+                ze_bool.permute(0, 2, 3, 1).reshape(
+                    ze_bool.size(0), -1, ze_bool.size(1)
+                )
+                * p
+            ).sum(-1)
+
+            seq.append(output)
+
+        return torch.cat(seq, dim=0)
+
+    def seq2frame(input, batch_size=25, T=1e-2):
+        frames = []
+        p = pow2.to(device)
+        for seq in input.split(batch_size):
+            seq = seq.to(device)
+            zd_bool = (seq[:, :, None] // p) % 2
+            zd_bool = zd_bool.reshape(zd_bool.size(0), z_h, z_w, -1).permute(0, 3, 1, 2)
+            logits = decoder(zd_bool * 2.0 - 1.0)
+            logits = logits.reshape(
+                logits.size(0), -1, 3, logits.size(2), logits.size(3)
+            ).permute(0, 2, 3, 4, 1)
+            output = torch.distributions.categorical.Categorical(
+                logits=logits / T
+            ).sample()
+
+            frames.append(output)
+
+        return torch.cat(frames, dim=0)
+
+    return train_input, train_actions, test_input, test_actions, frame2seq, seq2frame
+
+
+######################################################################
+
+if __name__ == "__main__":
+    (
+        train_input,
+        train_actions,
+        test_input,
+        test_actions,
+        frame2seq,
+        seq2frame,
+    ) = create_data_and_processors(
+        25000,
+        1000,
+        nb_epochs=5,
+        mode="first_last",
+        nb_steps=20,
+    )
+
+    input = test_input[:256]
+
+    seq = frame2seq(input)
+    output = seq2frame(seq)
+
+    torchvision.utils.save_image(
+        input.float() / (Box.nb_rgb_levels - 1), "orig.png", nrow=16
+    )
+
+    torchvision.utils.save_image(
+        output.float() / (Box.nb_rgb_levels - 1), "qtiz.png", nrow=16
+    )