)
for t in range(nb_steps):
- op = torch.randint(2, (nb,))
- st = torch.randint(nb_stacks, (nb,))
- op = op * (stack_counts[k, st] > 0)
- if values is None:
+ op = torch.randint(2, (nb,)) # what operation (push/pop)
+ st = torch.randint(nb_stacks, (nb,)) # on what stack
+ op = op * (stack_counts[k, st] > 0) # can only push is stack is empty
+
+ if values is None: # we can use all the values
val_push = torch.randint(10**nb_digits, (nb,))
- else:
+ else: # values are constrained (e.g. to have train/test values disjoint)
val_push = values[torch.randint(values.size(0), (nb,))]
- val_pop = stack[
+
+ val_pop = stack[ # if we were popping, what value would that be?
k,
st,
- (stack_counts[k, st] - 1).clamp(min=0),
+ (stack_counts[k, st] - 1).clamp(min=0), # deal with empty stack
]
+
+ # we always push the value, but it will be lost if we pop
+ # since we will move the count down
stack[k, st, stack_counts[k, st]] = val_push
recorded_stack_counts[:, (1 + nb_digits) * t] = stack_counts[k, st]
+
+ # we increase the stack count only when we actually push
stack_counts[k[op == 0], st[op == 0]] += 1
stack_counts[k[op == 1], st[op == 1]] -= 1
+
+ # add the operation number to the sequence, that incude the stack number
result[:, (1 + nb_digits) * t] = st * 2 + op
+
+ # add the digits to the sequence
for d in range(nb_digits):
result[:, (1 + nb_digits) * t + 1 + d] = (
(op * val_pop + (1 - op) * val_push) // (10**d)
seq[:, k:] = -m[:, :-k] + (1 - m[:, :-k]) * seq[:, k:]
-def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None):
- assert seq.size(0) % (1 + nb_digits) == 0
- s = ""
- for t in range(seq.size(0) // (1 + nb_digits)):
- n_op = seq[(1 + nb_digits) * t]
- if t > 0:
- s += " "
- if recorded_stack_counts is not None:
- s += f"[{recorded_stack_counts[(1 + nb_digits)*t]}] "
- s += f"POP" if n_op % 2 == 1 else f"PSH"
- if nb_stacks > 1:
- s += f"_{n_op//2}"
- for d in range(nb_digits):
- if seq[(1 + nb_digits) * t + 1 + d] == -1:
- s += " ?"
- else:
- s += f" {seq[(1 + nb_digits) * t + 1 + d] - 2 * nb_stacks:1d}"
- return s
+def seq_to_str(seq, nb_stacks, nb_digits):
+ def n_to_str(n):
+ if n < 0:
+ return "?"
+ elif n < 2 * nb_stacks:
+ s = f"POP" if n % 2 == 1 else f"PSH"
+ if nb_stacks > 1:
+ s += f"_{n//2}"
+ return s
+ elif n < 2 * nb_stacks + 10:
+ return f"{n - 2 * nb_stacks}"
+ else:
+ return "#"
+
+ return " ".join([n_to_str(x.item()) for x in seq])
######################################################################
if __name__ == "__main__":
+ seq, recorded_stack_counts = generate_sequences(
+ nb=3,
+ nb_steps=6,
+ nb_stacks=3,
+ nb_digits=3,
+ )
+
+ sep = torch.full((seq.size(0), 1), seq.max() + 1)
+
+ seq = torch.cat([seq, sep, seq], dim=1)
+
+ for n in range(min(10, seq.size(0))):
+ print(seq_to_str(seq[n], nb_stacks=3, nb_digits=3))
+
+ remove_popped_values(seq, 3, 3)
+
+ print()
+
+ for n in range(min(10, seq.size(0))):
+ print(seq_to_str(seq[n], nb_stacks=3, nb_digits=3))
+
+ exit(0)
+
nb, nb_steps, nb_stacks, nb_digits = 150000, 20, 2, 1
seq, recorded_stack_counts = generate_sequences(
nb=nb,
print("-- PREPARED FOR TEST -----------------")
+ print("SANITY", seq.size())
+
remove_popped_values(seq, nb_stacks, nb_digits)
for n in range(min(10, seq.size(0))):