From 76671c582f029aa67fce2626764b02e8d9e2dbeb Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 1 Jul 2023 19:42:47 +0200 Subject: [PATCH] Update. --- main.py | 116 ++++++++++++++++++++++++++++++++++++++++++++++++++++++- stack.py | 50 ++++++++++++++++-------- 2 files changed, 148 insertions(+), 18 deletions(-) diff --git a/main.py b/main.py index 45bddb7..0323d02 100755 --- a/main.py +++ b/main.py @@ -32,7 +32,7 @@ parser = argparse.ArgumentParser( ) parser.add_argument( - "--task", type=str, default="picoclvr", help="picoclvr, mnist, maze, snake" + "--task", type=str, default="picoclvr", help="picoclvr, mnist, maze, snake, stack" ) parser.add_argument("--log_filename", type=str, default="train.log", help=" ") @@ -106,6 +106,15 @@ parser.add_argument("--snake_nb_colors", type=int, default=5) parser.add_argument("--snake_length", type=int, default=200) +############################## +# Snake options + +parser.add_argument("--stack_nb_steps", type=int, default=25) + +parser.add_argument("--stack_nb_stacks", type=int, default=1) + +parser.add_argument("--stack_nb_values", type=int, default=10) + ###################################################################### args = parser.parse_args() @@ -135,18 +144,32 @@ default_args = { "picoclvr": { "nb_epochs": 25, "batch_size": 25, + "nb_train_samples": 250000, + "nb_test_samples": 10000, }, "mnist": { "nb_epochs": 25, "batch_size": 10, + "nb_train_samples": 250000, + "nb_test_samples": 10000, }, "maze": { "nb_epochs": 25, "batch_size": 25, + "nb_train_samples": 250000, + "nb_test_samples": 10000, }, "snake": { "nb_epochs": 5, "batch_size": 25, + "nb_train_samples": 250000, + "nb_test_samples": 10000, + }, + "stack": { + "nb_epochs": 25, + "batch_size": 25, + "nb_train_samples": 10000, + "nb_test_samples": 1000, }, } @@ -841,6 +864,86 @@ class TaskSnake(Task): ###################################################################### +import stack + + +class TaskStack(Task): + def __init__( + self, + nb_train_samples, + nb_test_samples, + batch_size, + nb_steps, + nb_stacks, + nb_values, + device=torch.device("cpu"), + ): + self.batch_size = batch_size + self.nb_steps = nb_steps + self.nb_stacks = nb_stacks + self.nb_values = nb_values + self.device = device + + self.train_input, self.train_stack_counts = stack.generate_sequences( + nb_train_samples, nb_steps, nb_stacks, nb_values, self.device + ) + + self.test_input, self.test_stack_counts = stack.generate_sequences( + nb_test_samples, nb_steps, nb_stacks, nb_values, self.device + ) + + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + + def batches(self, split="train", nb_to_use=-1, desc=None): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + if nb_to_use > 0: + input = input[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=desc + ): + yield batch + + def vocabulary_size(self): + return self.nb_codes + + def produce_results(self, n_epoch, model): + with torch.autograd.no_grad(): + t = model.training + model.eval() + + def compute_nb_correct(input): + result = input.clone() + stack.remove_poped_values(result,self.nb_stacks) + ar_mask = (result != input).long() + result *= 1 - ar_mask + + masked_inplace_autoregression( + model, self.batch_size, result, ar_mask, device=self.device + ) + + nb_total = ar_mask.sum() + + nb_correct = ( + (result == input).long() * ar_mask + ).sum() + + return nb_total, nb_correct + + test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000]) + + log_string( + f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" + ) + + model.train(t) + + +###################################################################### + + def picoclvr_pruner_horizontal_green(p): return not ("green" in p and ("left" in p or "right" in p)) @@ -902,6 +1005,17 @@ elif args.task == "snake": device=device, ) +elif args.task == "stack": + task = TaskStack( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + nb_steps = args.stack_nb_steps, + nb_stacks = args.stack_nb_stacks, + nb_values = args.stack_nb_values, + device=device, + ) + else: raise ValueError(f"Unknown task {args.task}") diff --git a/stack.py b/stack.py index dc494bb..312b39f 100755 --- a/stack.py +++ b/stack.py @@ -13,44 +13,52 @@ import torch, torchvision # CODE_VAL=val + 2 * nb_stacks -def generate(nb, nb_steps, nb_stacks, nb_values): +def generate_sequences(nb, nb_steps, nb_stacks, nb_values, device=torch.device("cpu")): stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64) - stack_pointers = torch.zeros(nb, nb_stacks, dtype=torch.int64) + stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64) k = torch.arange(nb) result = torch.empty(nb, 2 * nb_steps, dtype=torch.int64) - depth_counts = torch.zeros(nb, 2 * nb_steps, dtype=torch.int64) + recorded_stack_counts = torch.zeros(nb, 2 * nb_steps, dtype=torch.int64) for t in range(nb_steps): op = torch.randint(2, (nb,)) st = torch.randint(nb_stacks, (nb,)) - op = op * (stack_pointers[k, st] > 0) + op = op * (stack_counts[k, st] > 0) val_push = torch.randint(nb_values, (nb,)) val_pop = stack[ k, st, - (stack_pointers[k, st] - 1).clamp(min=0), + (stack_counts[k, st] - 1).clamp(min=0), ] - stack[k, st, stack_pointers[k, st]] = val_push - depth_counts[:, 2 * t + 1] = stack_pointers[k, st] - stack_pointers[k[op == 0], st[op == 0]] += 1 - stack_pointers[k[op == 1], st[op == 1]] -= 1 + stack[k, st, stack_counts[k, st]] = val_push + recorded_stack_counts[:, 2 * t + 1] = stack_counts[k, st] + stack_counts[k[op == 0], st[op == 0]] += 1 + stack_counts[k[op == 1], st[op == 1]] -= 1 result[:, 2 * t] = st * 2 + op result[:, 2 * t + 1] = (op * val_pop + (1 - op) * val_push) + 2 * nb_stacks - return result, depth_counts + return result.to(device), recorded_stack_counts.to(device) -def seq_to_str(seq, depth_counts=None): +def remove_poped_values(seq, nb_stacks): + m = torch.logical_and(seq % 2 == 1, seq < 2 * nb_stacks).long() + seq[:, 1:] = -m[:, :-1] + (1 - m[:, :-1]) * seq[:, 1:] + + +def seq_to_str(seq, recorded_stack_counts=None): assert seq.size(0) % 2 == 0 s = "" for t in range(seq.size(0) // 2): op = seq[2 * t] - op = f"POP_{op//2}" if op % 2 == 1 else f"PUSH_{op//2}" - val = seq[2 * t + 1] - 2 * nb_stacks + op = f"POP_{op//2}" if op % 2 == 1 else f"PSH_{op//2}" + if seq[2 * t + 1] == -1: + val = "?" + else: + val = seq[2 * t + 1] - 2 * nb_stacks if t > 0: s += " " - if depth_counts is not None: - s += f"[{depth_counts[2*t+1]}] " + if recorded_stack_counts is not None: + s += f"[{recorded_stack_counts[2*t+1]}] " s += f"{op} {val}" return s @@ -59,9 +67,17 @@ def seq_to_str(seq, depth_counts=None): if __name__ == "__main__": nb, nb_steps, nb_stacks, nb_values = 150000, 10, 1, 5 - seq, depth_counts = generate( + seq, recorded_stack_counts = generate_sequences( nb=nb, nb_steps=nb_steps, nb_stacks=nb_stacks, nb_values=nb_values ) for n in range(min(10, seq.size(0))): - print(seq_to_str(seq[n], depth_counts[n])) + # print(seq_to_str(seq[n], recorded_stack_counts[n])) + print(seq_to_str(seq[n])) + + print("--------------------------------------") + + remove_poped_values(seq, nb_stacks) + + for n in range(min(10, seq.size(0))): + print(seq_to_str(seq[n])) -- 2.20.1