X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=stack.py;h=69a696d115100b94fdb720efe261b184d044881e;hb=HEAD;hp=543f04e5628b7276769241922c3770afd5d416f9;hpb=4395f9a90218819997c706de9505cda1c86ad507;p=mygptrnn.git diff --git a/stack.py b/stack.py index 543f04e..69a696d 100755 --- a/stack.py +++ b/stack.py @@ -25,23 +25,34 @@ def generate_sequences( ) for t in range(nb_steps): - op = torch.randint(2, (nb,)) - st = torch.randint(nb_stacks, (nb,)) - op = op * (stack_counts[k, st] > 0) - if values is None: + op = torch.randint(2, (nb,)) # what operation (push/pop) + st = torch.randint(nb_stacks, (nb,)) # on what stack + op = op * (stack_counts[k, st] > 0) # can only push is stack is empty + + if values is None: # we can use all the values val_push = torch.randint(10**nb_digits, (nb,)) - else: + else: # values are constrained (e.g. to have train/test values disjoint) val_push = values[torch.randint(values.size(0), (nb,))] - val_pop = stack[ + + val_pop = stack[ # if we were popping, what value would that be? k, st, - (stack_counts[k, st] - 1).clamp(min=0), + (stack_counts[k, st] - 1).clamp(min=0), # deal with empty stack ] + + # we always push the value, but it will be lost if we pop + # since we will move the count down stack[k, st, stack_counts[k, st]] = val_push recorded_stack_counts[:, (1 + nb_digits) * t] = stack_counts[k, st] + + # we increase the stack count only when we actually push stack_counts[k[op == 0], st[op == 0]] += 1 stack_counts[k[op == 1], st[op == 1]] -= 1 + + # add the operation number to the sequence, that incude the stack number result[:, (1 + nb_digits) * t] = st * 2 + op + + # add the digits to the sequence for d in range(nb_digits): result[:, (1 + nb_digits) * t + 1 + d] = ( (op * val_pop + (1 - op) * val_push) // (10**d) @@ -57,29 +68,49 @@ def remove_popped_values(seq, nb_stacks, nb_digits): seq[:, k:] = -m[:, :-k] + (1 - m[:, :-k]) * seq[:, k:] -def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None): - assert seq.size(0) % (1 + nb_digits) == 0 - s = "" - for t in range(seq.size(0) // (1 + nb_digits)): - n_op = seq[(1 + nb_digits) * t] - if t > 0: - s += " " - if recorded_stack_counts is not None: - s += f"[{recorded_stack_counts[(1 + nb_digits)*t]}] " - s += f"POP" if n_op % 2 == 1 else f"PSH" - if nb_stacks > 1: - s += f"_{n_op//2}" - for d in range(nb_digits): - if seq[(1 + nb_digits) * t + 1 + d] == -1: - s += " ?" - else: - s += f" {seq[(1 + nb_digits) * t + 1 + d] - 2 * nb_stacks:1d}" - return s +def seq_to_str(seq, nb_stacks, nb_digits): + def n_to_str(n): + if n < 0: + return "?" + elif n < 2 * nb_stacks: + s = f"POP" if n % 2 == 1 else f"PSH" + if nb_stacks > 1: + s += f"_{n//2}" + return s + elif n < 2 * nb_stacks + 10: + return f"{n - 2 * nb_stacks}" + else: + return "#" + + return " ".join([n_to_str(x.item()) for x in seq]) ###################################################################### if __name__ == "__main__": + seq, recorded_stack_counts = generate_sequences( + nb=3, + nb_steps=6, + nb_stacks=3, + nb_digits=3, + ) + + sep = torch.full((seq.size(0), 1), seq.max() + 1) + + seq = torch.cat([seq, sep, seq], dim=1) + + for n in range(min(10, seq.size(0))): + print(seq_to_str(seq[n], nb_stacks=3, nb_digits=3)) + + remove_popped_values(seq, 3, 3) + + print() + + for n in range(min(10, seq.size(0))): + print(seq_to_str(seq[n], nb_stacks=3, nb_digits=3)) + + exit(0) + nb, nb_steps, nb_stacks, nb_digits = 150000, 20, 2, 1 seq, recorded_stack_counts = generate_sequences( nb=nb, @@ -101,6 +132,8 @@ if __name__ == "__main__": print("-- PREPARED FOR TEST -----------------") + print("SANITY", seq.size()) + remove_popped_values(seq, nb_stacks, nb_digits) for n in range(min(10, seq.size(0))):