Update.
[mygptrnn.git] / stack.py
index 543f04e..69a696d 100755 (executable)
--- a/stack.py
+++ b/stack.py
@@ -25,23 +25,34 @@ def generate_sequences(
     )
 
     for t in range(nb_steps):
-        op = torch.randint(2, (nb,))
-        st = torch.randint(nb_stacks, (nb,))
-        op = op * (stack_counts[k, st] > 0)
-        if values is None:
+        op = torch.randint(2, (nb,))  # what operation (push/pop)
+        st = torch.randint(nb_stacks, (nb,))  # on what stack
+        op = op * (stack_counts[k, st] > 0)  # can only push is stack is empty
+
+        if values is None:  # we can use all the values
             val_push = torch.randint(10**nb_digits, (nb,))
-        else:
+        else:  # values are constrained (e.g. to have train/test values disjoint)
             val_push = values[torch.randint(values.size(0), (nb,))]
-        val_pop = stack[
+
+        val_pop = stack[  # if we were popping, what value would that be?
             k,
             st,
-            (stack_counts[k, st] - 1).clamp(min=0),
+            (stack_counts[k, st] - 1).clamp(min=0),  # deal with empty stack
         ]
+
+        # we always push the value, but it will be lost if we pop
+        # since we will move the count down
         stack[k, st, stack_counts[k, st]] = val_push
         recorded_stack_counts[:, (1 + nb_digits) * t] = stack_counts[k, st]
+
+        # we increase the stack count only when we actually push
         stack_counts[k[op == 0], st[op == 0]] += 1
         stack_counts[k[op == 1], st[op == 1]] -= 1
+
+        # add the operation number to the sequence, that incude the stack number
         result[:, (1 + nb_digits) * t] = st * 2 + op
+
+        # add the digits to the sequence
         for d in range(nb_digits):
             result[:, (1 + nb_digits) * t + 1 + d] = (
                 (op * val_pop + (1 - op) * val_push) // (10**d)
@@ -57,29 +68,49 @@ def remove_popped_values(seq, nb_stacks, nb_digits):
         seq[:, k:] = -m[:, :-k] + (1 - m[:, :-k]) * seq[:, k:]
 
 
-def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None):
-    assert seq.size(0) % (1 + nb_digits) == 0
-    s = ""
-    for t in range(seq.size(0) // (1 + nb_digits)):
-        n_op = seq[(1 + nb_digits) * t]
-        if t > 0:
-            s += " "
-        if recorded_stack_counts is not None:
-            s += f"[{recorded_stack_counts[(1 + nb_digits)*t]}] "
-        s += f"POP" if n_op % 2 == 1 else f"PSH"
-        if nb_stacks > 1:
-            s += f"_{n_op//2}"
-        for d in range(nb_digits):
-            if seq[(1 + nb_digits) * t + 1 + d] == -1:
-                s += " ?"
-            else:
-                s += f" {seq[(1 + nb_digits) * t + 1 + d] - 2 * nb_stacks:1d}"
-    return s
+def seq_to_str(seq, nb_stacks, nb_digits):
+    def n_to_str(n):
+        if n < 0:
+            return "?"
+        elif n < 2 * nb_stacks:
+            s = f"POP" if n % 2 == 1 else f"PSH"
+            if nb_stacks > 1:
+                s += f"_{n//2}"
+                return s
+        elif n < 2 * nb_stacks + 10:
+            return f"{n - 2 * nb_stacks}"
+        else:
+            return "#"
+
+    return " ".join([n_to_str(x.item()) for x in seq])
 
 
 ######################################################################
 
 if __name__ == "__main__":
+    seq, recorded_stack_counts = generate_sequences(
+        nb=3,
+        nb_steps=6,
+        nb_stacks=3,
+        nb_digits=3,
+    )
+
+    sep = torch.full((seq.size(0), 1), seq.max() + 1)
+
+    seq = torch.cat([seq, sep, seq], dim=1)
+
+    for n in range(min(10, seq.size(0))):
+        print(seq_to_str(seq[n], nb_stacks=3, nb_digits=3))
+
+    remove_popped_values(seq, 3, 3)
+
+    print()
+
+    for n in range(min(10, seq.size(0))):
+        print(seq_to_str(seq[n], nb_stacks=3, nb_digits=3))
+
+    exit(0)
+
     nb, nb_steps, nb_stacks, nb_digits = 150000, 20, 2, 1
     seq, recorded_stack_counts = generate_sequences(
         nb=nb,
@@ -101,6 +132,8 @@ if __name__ == "__main__":
 
     print("-- PREPARED FOR TEST -----------------")
 
+    print("SANITY", seq.size())
+
     remove_popped_values(seq, nb_stacks, nb_digits)
 
     for n in range(min(10, seq.size(0))):