X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=stack.py;h=543f04e5628b7276769241922c3770afd5d416f9;hb=HEAD;hp=219a1ad013692e49429d85866950b1afe7e5cbf6;hpb=4502a109727b0424ff6d4df90f17b361524f9e73;p=picoclvr.git diff --git a/stack.py b/stack.py index 219a1ad..543f04e 100755 --- a/stack.py +++ b/stack.py @@ -38,7 +38,7 @@ def generate_sequences( (stack_counts[k, st] - 1).clamp(min=0), ] stack[k, st, stack_counts[k, st]] = val_push - recorded_stack_counts[:, (1 + nb_digits) * t + 1] = stack_counts[k, st] + recorded_stack_counts[:, (1 + nb_digits) * t] = stack_counts[k, st] stack_counts[k[op == 0], st[op == 0]] += 1 stack_counts[k[op == 1], st[op == 1]] -= 1 result[:, (1 + nb_digits) * t] = st * 2 + op @@ -64,6 +64,8 @@ def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None): n_op = seq[(1 + nb_digits) * t] if t > 0: s += " " + if recorded_stack_counts is not None: + s += f"[{recorded_stack_counts[(1 + nb_digits)*t]}] " s += f"POP" if n_op % 2 == 1 else f"PSH" if nb_stacks > 1: s += f"_{n_op//2}" @@ -72,15 +74,13 @@ def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None): s += " ?" else: s += f" {seq[(1 + nb_digits) * t + 1 + d] - 2 * nb_stacks:1d}" - if recorded_stack_counts is not None: - s += f"[{recorded_stack_counts[(1 + nb_digits)*t+1]}] " return s ###################################################################### if __name__ == "__main__": - nb, nb_steps, nb_stacks, nb_digits = 150000, 10, 1, 1 + nb, nb_steps, nb_stacks, nb_digits = 150000, 20, 2, 1 seq, recorded_stack_counts = generate_sequences( nb=nb, nb_steps=nb_steps, @@ -88,13 +88,18 @@ if __name__ == "__main__": nb_digits=nb_digits, ) - print("-- TRAIN -----------------------------") - for n in range(min(10, seq.size(0))): - # print(seq_to_str(seq[n], recorded_stack_counts[n])) - print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits)) - - print("-- TEST ------------------------------") + print( + seq_to_str( + seq[n], + nb_stacks=nb_stacks, + nb_digits=nb_digits, + recorded_stack_counts=recorded_stack_counts[n], + ) + ) + # print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits)) + + print("-- PREPARED FOR TEST -----------------") remove_popped_values(seq, nb_stacks, nb_digits)