Oups
[picoclvr.git] / stack.py
index 219a1ad..543f04e 100755 (executable)
--- a/stack.py
+++ b/stack.py
@@ -38,7 +38,7 @@ def generate_sequences(
             (stack_counts[k, st] - 1).clamp(min=0),
         ]
         stack[k, st, stack_counts[k, st]] = val_push
-        recorded_stack_counts[:, (1 + nb_digits) * t + 1] = stack_counts[k, st]
+        recorded_stack_counts[:, (1 + nb_digits) * t] = stack_counts[k, st]
         stack_counts[k[op == 0], st[op == 0]] += 1
         stack_counts[k[op == 1], st[op == 1]] -= 1
         result[:, (1 + nb_digits) * t] = st * 2 + op
@@ -64,6 +64,8 @@ def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None):
         n_op = seq[(1 + nb_digits) * t]
         if t > 0:
             s += " "
+        if recorded_stack_counts is not None:
+            s += f"[{recorded_stack_counts[(1 + nb_digits)*t]}] "
         s += f"POP" if n_op % 2 == 1 else f"PSH"
         if nb_stacks > 1:
             s += f"_{n_op//2}"
@@ -72,15 +74,13 @@ def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None):
                 s += " ?"
             else:
                 s += f" {seq[(1 + nb_digits) * t + 1 + d] - 2 * nb_stacks:1d}"
-        if recorded_stack_counts is not None:
-            s += f"[{recorded_stack_counts[(1 + nb_digits)*t+1]}] "
     return s
 
 
 ######################################################################
 
 if __name__ == "__main__":
-    nb, nb_steps, nb_stacks, nb_digits = 150000, 10, 1, 1
+    nb, nb_steps, nb_stacks, nb_digits = 150000, 20, 2, 1
     seq, recorded_stack_counts = generate_sequences(
         nb=nb,
         nb_steps=nb_steps,
@@ -88,13 +88,18 @@ if __name__ == "__main__":
         nb_digits=nb_digits,
     )
 
-    print("-- TRAIN -----------------------------")
-
     for n in range(min(10, seq.size(0))):
-        # print(seq_to_str(seq[n], recorded_stack_counts[n]))
-        print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))
-
-    print("-- TEST ------------------------------")
+        print(
+            seq_to_str(
+                seq[n],
+                nb_stacks=nb_stacks,
+                nb_digits=nb_digits,
+                recorded_stack_counts=recorded_stack_counts[n],
+            )
+        )
+        # print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))
+
+    print("-- PREPARED FOR TEST -----------------")
 
     remove_popped_values(seq, nb_stacks, nb_digits)