3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import torch, torchvision
10 ######################################################################
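# Token encoding, one operation per (1 + nb_digits) tokens:
# CODE_OP  = [0 for push, 1 for pop] + 2 * n_stack   (n_stack: index of the stack)
# CODE_VAL = digit + 2 * nb_stacks   (one token per decimal digit of the value,
#            least-significant digit first)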
def generate_sequences(nb, nb_steps, nb_stacks, nb_digits, device=torch.device("cpu")):
    """Generate random push/pop token sequences over ``nb_stacks`` stacks.

    At each of the ``nb_steps`` steps, every sequence picks one of its stacks
    at random and either pushes a random ``nb_digits``-digit value onto it or
    pops its top value. A pop drawn for an empty stack is turned into a push.

    Each step is encoded as ``1 + nb_digits`` tokens:

    * an operation token ``2 * stack_index + op`` (``op`` is 0 for push, 1 for
      pop), so operation tokens lie in ``[0, 2 * nb_stacks)``;
    * ``nb_digits`` digit tokens, least-significant digit first, each offset by
      ``2 * nb_stacks``. For a push these are the digits of the pushed value,
      for a pop the digits of the popped value.

    Args:
        nb: number of sequences to generate.
        nb_steps: number of operations per sequence.
        nb_stacks: number of independent stacks per sequence.
        nb_digits: number of decimal digits of the stored values.
        device: device the returned tensors are moved to (generation itself
            runs on the CPU).

    Returns:
        ``(result, recorded_stack_counts)``, two int64 tensors of shape
        ``(nb, (1 + nb_digits) * nb_steps)``. ``result`` holds the tokens.
        ``recorded_stack_counts`` is zero except at the position of each
        step's first digit token, where it holds the size of the selected
        stack before that step's operation is applied.
    """
    # stack[n, s, i] is the i-th element (from the bottom) of stack s of
    # sequence n. Zero-initialised rather than torch.empty so the masked-out
    # "pop from an empty stack" read below never touches uninitialised memory.
    stack = torch.zeros(nb, nb_stacks, nb_steps, dtype=torch.int64)
    stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64)
    # Row index, paired element-wise with `st` in the advanced indexing below.
    k = torch.arange(nb)

    result = torch.empty(nb, (1 + nb_digits) * nb_steps, dtype=torch.int64)
    recorded_stack_counts = torch.zeros(
        nb, (1 + nb_digits) * nb_steps, dtype=torch.int64
    )

    for t in range(nb_steps):
        op = torch.randint(2, (nb,))
        st = torch.randint(nb_stacks, (nb,))
        # Force a push whenever the selected stack is empty.
        op = op * (stack_counts[k, st] > 0)
        val_push = torch.randint(10**nb_digits, (nb,))
        # Top of the selected stack; only meaningful (and only used) when op == 1.
        val_pop = stack[k, st, (stack_counts[k, st] - 1).clamp(min=0)]
        # Written for every row. For a pop this lands at index == count, above
        # the top; slots at or above the count are never read before a push
        # overwrites them. At most one push per step means count <= t < nb_steps,
        # so the index is always in range.
        stack[k, st, stack_counts[k, st]] = val_push
        recorded_stack_counts[:, (1 + nb_digits) * t + 1] = stack_counts[k, st]
        stack_counts[k[op == 0], st[op == 0]] += 1
        stack_counts[k[op == 1], st[op == 1]] -= 1
        result[:, (1 + nb_digits) * t] = st * 2 + op
        value = op * val_pop + (1 - op) * val_push
        for d in range(nb_digits):
            result[:, (1 + nb_digits) * t + 1 + d] = (
                value // (10**d)
            ) % 10 + 2 * nb_stacks

    return result.to(device), recorded_stack_counts.to(device)
def remove_popped_values(seq, nb_stacks, nb_digits):
    """Mask, in place, the digit tokens that follow every pop operation.

    Operation tokens are ``2 * stack_index + op`` with ``op == 1`` for a pop,
    so pop tokens are the odd values in ``[0, 2 * nb_stacks)``. The
    ``nb_digits`` tokens right after each pop token (the popped value) are set
    to -1; every other token is left unchanged.

    Args:
        seq: integer tensor with one token sequence per row; modified in place.
        nb_stacks: number of stacks used to encode ``seq``.
        nb_digits: number of digit tokens following each operation token.

    Returns:
        None. ``seq`` is updated in place.
    """
    # A masked token (-1) is also odd (torch `%` takes the divisor's sign) and
    # below 2 * nb_stacks. Exclude negatives so a repeated call does not mistake
    # masked digits for pops and clobber the next operation token.
    m = ((seq >= 0) & (seq % 2 == 1) & (seq < 2 * nb_stacks)).long()
    for d in range(nb_digits):
        k = d + 1  # distance from an operation token to its d-th digit token
        seq[:, k:] = -m[:, :-k] + (1 - m[:, :-k]) * seq[:, k:]
def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None):
    """Render one encoded sequence as a human-readable string.

    Each step (``1 + nb_digits`` tokens) becomes ``PSH`` or ``POP``, suffixed
    with ``_<stack index>`` when there is more than one stack, followed by its
    digits (least-significant first, as encoded). A digit masked to -1 is shown
    as ``?``. Steps are separated by a single space.

    Args:
        seq: 1-D integer tensor holding one encoded sequence.
        nb_stacks: number of stacks used to encode ``seq``.
        nb_digits: number of digit tokens per operation.
        recorded_stack_counts: optional 1-D tensor aligned with ``seq``. If
            given, the value stored at each step's first digit position is
            appended to that step as ``[count]``.

    Returns:
        The rendered string.
    """
    assert seq.size(0) % (1 + nb_digits) == 0
    steps = []
    for t in range(seq.size(0) // (1 + nb_digits)):
        base = (1 + nb_digits) * t
        n_op = seq[base]
        s = "POP" if n_op % 2 == 1 else "PSH"
        if nb_stacks > 1:
            s += f"_{n_op // 2}"
        for d in range(nb_digits):
            token = seq[base + 1 + d]
            if token == -1:
                s += " ?"  # digit hidden by remove_popped_values
            else:
                s += f" {token - 2 * nb_stacks:1d}"
        if recorded_stack_counts is not None:
            s += f"[{recorded_stack_counts[base + 1]}] "
        steps.append(s)
    return " ".join(steps)
75 ######################################################################
if __name__ == "__main__":
    # Demo: generate sequences, print a few, then print them again with the
    # popped values hidden.
    nb, nb_steps, nb_stacks, nb_digits = 150000, 10, 1, 1
    seq, recorded_stack_counts = generate_sequences(
        nb=nb,
        nb_steps=nb_steps,
        nb_stacks=nb_stacks,
        nb_digits=nb_digits,
    )

    print("-- TRAIN -----------------------------")

    for n in range(min(10, seq.size(0))):
        # Add recorded_stack_counts=recorded_stack_counts[n] to also show the
        # stack size before each operation (it must be passed by keyword: the
        # second positional parameter is nb_stacks).
        print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))

    print("-- TEST ------------------------------")

    remove_popped_values(seq, nb_stacks, nb_digits)

    for n in range(min(10, seq.size(0))):
        print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))