Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 1 Jul 2023 11:13:55 +0000 (13:13 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 1 Jul 2023 11:13:55 +0000 (13:13 +0200)
stack.py [new file with mode: 0755]

diff --git a/stack.py b/stack.py
new file mode 100755 (executable)
index 0000000..d3be4f8
--- /dev/null
+++ b/stack.py
@@ -0,0 +1,65 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import torch, torchvision
+
+######################################################################
+
+# CODE_OP=[0 for push, 1 for pop] + 2 * n_stack
+# CODE_VAL=val + 2 * nb_stacks
+
+
+def generate(nb, seq_len, nb_stacks, nb_values):
+    stack = torch.empty(nb, nb_stacks, seq_len, dtype=torch.int64)
+    stack_pointers = torch.zeros(nb, nb_stacks, dtype=torch.int64)
+    k = torch.arange(nb)
+    result = torch.empty(nb, 2 * seq_len, dtype=torch.int64)
+
+    for t in range(seq_len):
+        op = torch.randint(2, (nb,))
+        st = torch.randint(nb_stacks, (nb,))
+        op = op * (stack_pointers[k, st] > 0)
+        val_push = torch.randint(nb_values, (nb,))
+        # top_val[n,s]=stack[n,stack_pointers[n,s]]
+        top_values = stack[
+            k,
+            st,
+            (stack_pointers[k, st] - 1).clamp(min=0),
+        ]
+        stack[
+            k[:, None].expand_as(stack_pointers),
+            st[:, None].expand_as(stack_pointers),
+            stack_pointers,
+        ] = val_push[:, None].expand_as(stack_pointers)
+        stack_pointers[k[op == 0], st[op == 0]] += 1
+        stack_pointers[k[op == 1], st[op == 1]] -= 1
+        result[:, 2 * t] = st * 2 + op
+        result[:, 2 * t + 1] = (op * top_values + (1 - op) * val_push) + 2 * nb_stacks
+
+    return result
+
+
+def seq_to_str(seq):
+    assert seq.size(0) % 2 == 0
+    s = ""
+    for t in range(0, seq.size(0), 2):
+        op = seq[t]
+        op = f"POP_{op//2}" if op % 2 == 1 else f"PUSH_{op//2}"
+        val = seq[t + 1] - 2 * nb_stacks
+        if t > 0:
+            s += " "
+        s += f"{op} {val}"
+    return s
+
+
+######################################################################
+
+if __name__ == "__main__":
+    nb, seq_len, nb_stacks, nb_values = 3, 10, 1, 5
+    result = generate(nb=nb, seq_len=seq_len, nb_stacks=nb_stacks, nb_values=nb_values)
+    for n in range(result.size(0)):
+        print(seq_to_str(result[n]))