Update.
[picoclvr.git] / stack.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import torch, torchvision
9
10 ######################################################################
11
12 # CODE_OP=[0 for push, 1 for pop] + 2 * n_stack
13 # CODE_VAL=val + 2 * nb_stacks
14
15
16 def generate_sequences(nb, nb_steps, nb_stacks, nb_values, device=torch.device("cpu")):
17     stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64)
18     stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64)
19     k = torch.arange(nb)
20     result = torch.empty(nb, 2 * nb_steps, dtype=torch.int64)
21     recorded_stack_counts = torch.zeros(nb, 2 * nb_steps, dtype=torch.int64)
22
23     for t in range(nb_steps):
24         op = torch.randint(2, (nb,))
25         st = torch.randint(nb_stacks, (nb,))
26         op = op * (stack_counts[k, st] > 0)
27         val_push = torch.randint(nb_values, (nb,))
28         val_pop = stack[
29             k,
30             st,
31             (stack_counts[k, st] - 1).clamp(min=0),
32         ]
33         stack[k, st, stack_counts[k, st]] = val_push
34         recorded_stack_counts[:, 2 * t + 1] = stack_counts[k, st]
35         stack_counts[k[op == 0], st[op == 0]] += 1
36         stack_counts[k[op == 1], st[op == 1]] -= 1
37         result[:, 2 * t] = st * 2 + op
38         result[:, 2 * t + 1] = (op * val_pop + (1 - op) * val_push) + 2 * nb_stacks
39
40     return result.to(device), recorded_stack_counts.to(device)
41
42
43 def remove_poped_values(seq, nb_stacks):
44     m = torch.logical_and(seq % 2 == 1, seq < 2 * nb_stacks).long()
45     seq[:, 1:] = -m[:, :-1] + (1 - m[:, :-1]) * seq[:, 1:]
46
47
48 def seq_to_str(seq, show_stack_nb=True,recorded_stack_counts=None):
49     assert seq.size(0) % 2 == 0
50     s = ""
51     for t in range(seq.size(0) // 2):
52         n_op = seq[2 * t]
53         op = f"POP" if n_op % 2 == 1 else f"PSH"
54         if show_stack_nb: op+=f"_{n_op//2}"
55         if seq[2 * t + 1] == -1:
56             val = "?"
57         else:
58             val = seq[2 * t + 1] - 2 * nb_stacks
59         if t > 0:
60             s += " "
61         if recorded_stack_counts is not None:
62             s += f"[{recorded_stack_counts[2*t+1]}] "
63         s += f"{op} {val}"
64     return s
65
66
67 ######################################################################
68
69 if __name__ == "__main__":
70     nb, nb_steps, nb_stacks, nb_values = 150000, 10, 1, 5
71     seq, recorded_stack_counts = generate_sequences(
72         nb=nb, nb_steps=nb_steps, nb_stacks=nb_stacks, nb_values=nb_values
73     )
74
75     print("-- TRAIN -----------------------------")
76
77     for n in range(min(10, seq.size(0))):
78         # print(seq_to_str(seq[n], recorded_stack_counts[n]))
79         print(seq_to_str(seq[n],show_stack_nb=nb_stacks>1))
80
81     print("-- TEST ------------------------------")
82
83     remove_poped_values(seq, nb_stacks)
84
85     for n in range(min(10, seq.size(0))):
86         print(seq_to_str(seq[n],show_stack_nb=nb_stacks>1))