Update.
[mygptrnn.git] / stack.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import torch, torchvision
9
10 ######################################################################
11
12 # CODE_OP=[0 for push, 1 for pop] + 2 * n_stack
13 # CODE_VAL=val + 2 * nb_stacks
14
15
16 def generate_sequences(
17     nb, nb_steps, nb_stacks, nb_digits, values=None, device=torch.device("cpu")
18 ):
19     stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64)
20     stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64)
21     k = torch.arange(nb)
22     result = torch.empty(nb, (1 + nb_digits) * nb_steps, dtype=torch.int64)
23     recorded_stack_counts = torch.zeros(
24         nb, (1 + nb_digits) * nb_steps, dtype=torch.int64
25     )
26
27     for t in range(nb_steps):
28         op = torch.randint(2, (nb,))  # what operation (push/pop)
29         st = torch.randint(nb_stacks, (nb,))  # on what stack
30         op = op * (stack_counts[k, st] > 0)  # can only push is stack is empty
31
32         if values is None:  # we can use all the values
33             val_push = torch.randint(10**nb_digits, (nb,))
34         else:  # values are constrained (e.g. to have train/test values disjoint)
35             val_push = values[torch.randint(values.size(0), (nb,))]
36
37         val_pop = stack[  # if we were popping, what value would that be?
38             k,
39             st,
40             (stack_counts[k, st] - 1).clamp(min=0),  # deal with empty stack
41         ]
42
43         # we always push the value, but it will be lost if we pop
44         # since we will move the count down
45         stack[k, st, stack_counts[k, st]] = val_push
46         recorded_stack_counts[:, (1 + nb_digits) * t] = stack_counts[k, st]
47
48         # we increase the stack count only when we actually push
49         stack_counts[k[op == 0], st[op == 0]] += 1
50         stack_counts[k[op == 1], st[op == 1]] -= 1
51
52         # add the operation number to the sequence, that incude the stack number
53         result[:, (1 + nb_digits) * t] = st * 2 + op
54
55         # add the digits to the sequence
56         for d in range(nb_digits):
57             result[:, (1 + nb_digits) * t + 1 + d] = (
58                 (op * val_pop + (1 - op) * val_push) // (10**d)
59             ) % 10 + 2 * nb_stacks
60
61     return result.to(device), recorded_stack_counts.to(device)
62
63
64 def remove_popped_values(seq, nb_stacks, nb_digits):
65     m = torch.logical_and(seq % 2 == 1, seq < 2 * nb_stacks).long()
66     for d in range(nb_digits):
67         k = d + 1
68         seq[:, k:] = -m[:, :-k] + (1 - m[:, :-k]) * seq[:, k:]
69
70
71 def seq_to_str(seq, nb_stacks, nb_digits):
72     def n_to_str(n):
73         if n < 0:
74             return "?"
75         elif n < 2 * nb_stacks:
76             s = f"POP" if n % 2 == 1 else f"PSH"
77             if nb_stacks > 1:
78                 s += f"_{n//2}"
79                 return s
80         elif n < 2 * nb_stacks + 10:
81             return f"{n - 2 * nb_stacks}"
82         else:
83             return "#"
84
85     return " ".join([n_to_str(x.item()) for x in seq])
86
87
88 ######################################################################
89
90 if __name__ == "__main__":
91     seq, recorded_stack_counts = generate_sequences(
92         nb=3,
93         nb_steps=6,
94         nb_stacks=3,
95         nb_digits=3,
96     )
97
98     sep = torch.full((seq.size(0), 1), seq.max() + 1)
99
100     seq = torch.cat([seq, sep, seq], dim=1)
101
102     for n in range(min(10, seq.size(0))):
103         print(seq_to_str(seq[n], nb_stacks=3, nb_digits=3))
104
105     remove_popped_values(seq, 3, 3)
106
107     print()
108
109     for n in range(min(10, seq.size(0))):
110         print(seq_to_str(seq[n], nb_stacks=3, nb_digits=3))
111
112     exit(0)
113
114     nb, nb_steps, nb_stacks, nb_digits = 150000, 20, 2, 1
115     seq, recorded_stack_counts = generate_sequences(
116         nb=nb,
117         nb_steps=nb_steps,
118         nb_stacks=nb_stacks,
119         nb_digits=nb_digits,
120     )
121
122     for n in range(min(10, seq.size(0))):
123         print(
124             seq_to_str(
125                 seq[n],
126                 nb_stacks=nb_stacks,
127                 nb_digits=nb_digits,
128                 recorded_stack_counts=recorded_stack_counts[n],
129             )
130         )
131         # print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))
132
133     print("-- PREPARED FOR TEST -----------------")
134
135     print("SANITY", seq.size())
136
137     remove_popped_values(seq, nb_stacks, nb_digits)
138
139     for n in range(min(10, seq.size(0))):
140         print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))