From 87da428a5ab9ac3cd49ab22bd27e572d0b16f29c Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 2 Jul 2023 17:45:08 +0200 Subject: [PATCH] Update. --- main.py | 60 ++++++++++++++++++++++++++++++++++------------------- stack.py | 63 +++++++++++++++++++++++++++++++++----------------------- 2 files changed, 76 insertions(+), 47 deletions(-) diff --git a/main.py b/main.py index 14b1bc3..314a961 100755 --- a/main.py +++ b/main.py @@ -45,9 +45,9 @@ parser.add_argument("--nb_epochs", type=int, default=None) parser.add_argument("--batch_size", type=int, default=None) -parser.add_argument("--nb_train_samples", type=int, default=250000) +parser.add_argument("--nb_train_samples", type=int, default=None) -parser.add_argument("--nb_test_samples", type=int, default=10000) +parser.add_argument("--nb_test_samples", type=int, default=None) parser.add_argument("--optim", type=str, default="adam") @@ -113,7 +113,7 @@ parser.add_argument("--stack_nb_steps", type=int, default=100) parser.add_argument("--stack_nb_stacks", type=int, default=1) -parser.add_argument("--stack_nb_values", type=int, default=10) +parser.add_argument("--stack_nb_digits", type=int, default=1) ###################################################################### @@ -214,7 +214,7 @@ def masked_inplace_autoregression( # entropy[:,s]= p.xlogy(p).sum(1) / math.log(2) batches = zip(input.split(batch_size), ar_mask.split(batch_size)) if progress_bar_desc is not None: - tqdm.tqdm( + batches = tqdm.tqdm( batches, dynamic_ncols=True, desc=progress_bar_desc, @@ -875,28 +875,28 @@ class TaskStack(Task): batch_size, nb_steps, nb_stacks, - nb_values, + nb_digits, device=torch.device("cpu"), ): self.batch_size = batch_size self.nb_steps = nb_steps self.nb_stacks = nb_stacks - self.nb_values = nb_values + self.nb_digits = nb_digits self.device = device self.train_input, self.train_stack_counts = stack.generate_sequences( - nb_train_samples, nb_steps, nb_stacks, nb_values, self.device + nb_train_samples, nb_steps, nb_stacks, nb_digits, self.device ) self.test_input, self.test_stack_counts = stack.generate_sequences( - nb_test_samples, nb_steps, nb_stacks, nb_values, self.device + nb_test_samples, nb_steps, nb_stacks, nb_digits, self.device ) mask = self.test_input.clone() - stack.remove_poped_values(mask,self.nb_stacks) - mask=(mask!=self.test_input) + stack.remove_popped_values(mask, self.nb_stacks, self.nb_digits) + mask = mask != self.test_input counts = self.test_stack_counts.flatten()[mask.flatten()] - counts=F.one_hot(counts).sum(0) + counts = F.one_hot(counts).sum(0) log_string(f"stack_count {counts}") self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 @@ -923,19 +923,19 @@ class TaskStack(Task): def compute_nb_correct(input): result = input.clone() - stack.remove_poped_values(result,self.nb_stacks) + stack.remove_popped_values(result, self.nb_stacks, self.nb_digits) ar_mask = (result != input).long() - result *= 1 - ar_mask - masked_inplace_autoregression( model, self.batch_size, result, ar_mask, device=self.device ) - nb_total = ar_mask.sum() + errors = ((result != input).long() * ar_mask).reshape( + -1, 1 + self.nb_digits + ) + ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits) - nb_correct = ( - (result == input).long() * ar_mask - ).sum() + nb_total = ar_mask.max(1).values.sum() + nb_correct = nb_total - errors.max(1).values.sum() return nb_total, nb_correct @@ -945,6 +945,24 @@ class TaskStack(Task): f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" ) + #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + input = self.test_input[:10, :20] + result = input.clone() + stack.remove_popped_values(result, self.nb_stacks, self.nb_digits) + ar_mask = (result != input).long() + for n in range(result.size(0)): + log_string( + f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}" + ) + masked_inplace_autoregression( + model, self.batch_size, result, ar_mask, device=self.device + ) + for n in range(result.size(0)): + log_string( + f"test_after {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}" + ) + #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + model.train(t) @@ -1017,9 +1035,9 @@ elif args.task == "stack": nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, - nb_steps = args.stack_nb_steps, - nb_stacks = args.stack_nb_stacks, - nb_values = args.stack_nb_values, + nb_steps=args.stack_nb_steps, + nb_stacks=args.stack_nb_stacks, + nb_digits=args.stack_nb_digits, device=device, ) diff --git a/stack.py b/stack.py index ba452aa..675182e 100755 --- a/stack.py +++ b/stack.py @@ -13,74 +13,85 @@ import torch, torchvision # CODE_VAL=val + 2 * nb_stacks -def generate_sequences(nb, nb_steps, nb_stacks, nb_values, device=torch.device("cpu")): +def generate_sequences(nb, nb_steps, nb_stacks, nb_digits, device=torch.device("cpu")): stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64) stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64) k = torch.arange(nb) - result = torch.empty(nb, 2 * nb_steps, dtype=torch.int64) - recorded_stack_counts = torch.zeros(nb, 2 * nb_steps, dtype=torch.int64) + result = torch.empty(nb, (1 + nb_digits) * nb_steps, dtype=torch.int64) + recorded_stack_counts = torch.zeros( + nb, (1 + nb_digits) * nb_steps, dtype=torch.int64 + ) for t in range(nb_steps): op = torch.randint(2, (nb,)) st = torch.randint(nb_stacks, (nb,)) op = op * (stack_counts[k, st] > 0) - val_push = torch.randint(nb_values, (nb,)) + val_push = torch.randint(10**nb_digits, (nb,)) val_pop = stack[ k, st, (stack_counts[k, st] - 1).clamp(min=0), ] stack[k, st, stack_counts[k, st]] = val_push - recorded_stack_counts[:, 2 * t + 1] = stack_counts[k, st] + recorded_stack_counts[:, (1 + nb_digits) * t + 1] = stack_counts[k, st] stack_counts[k[op == 0], st[op == 0]] += 1 stack_counts[k[op == 1], st[op == 1]] -= 1 - result[:, 2 * t] = st * 2 + op - result[:, 2 * t + 1] = (op * val_pop + (1 - op) * val_push) + 2 * nb_stacks + result[:, (1 + nb_digits) * t] = st * 2 + op + for d in range(nb_digits): + result[:, (1 + nb_digits) * t + 1 + d] = ( + (op * val_pop + (1 - op) * val_push) // (10**d) + ) % 10 + 2 * nb_stacks return result.to(device), recorded_stack_counts.to(device) -def remove_poped_values(seq, nb_stacks): +def remove_popped_values(seq, nb_stacks, nb_digits): m = torch.logical_and(seq % 2 == 1, seq < 2 * nb_stacks).long() - seq[:, 1:] = -m[:, :-1] + (1 - m[:, :-1]) * seq[:, 1:] + for d in range(nb_digits): + k = d + 1 + seq[:, k:] = -m[:, :-k] + (1 - m[:, :-k]) * seq[:, k:] -def seq_to_str(seq, show_stack_nb=True,recorded_stack_counts=None): - assert seq.size(0) % 2 == 0 +def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None): + assert seq.size(0) % (1 + nb_digits) == 0 s = "" - for t in range(seq.size(0) // 2): - n_op = seq[2 * t] - op = f"POP" if n_op % 2 == 1 else f"PSH" - if show_stack_nb: op+=f"_{n_op//2}" - if seq[2 * t + 1] == -1: - val = "?" - else: - val = seq[2 * t + 1] - 2 * nb_stacks + for t in range(seq.size(0) // (1 + nb_digits)): + n_op = seq[(1 + nb_digits) * t] if t > 0: s += " " + s += f"POP" if n_op % 2 == 1 else f"PSH" + if nb_stacks > 1: + s += f"_{n_op//2}" + for d in range(nb_digits): + if seq[(1 + nb_digits) * t + 1 + d] == -1: + s += " ?" + else: + s += f" {seq[(1 + nb_digits) * t + 1 + d] - 2 * nb_stacks:1d}" if recorded_stack_counts is not None: - s += f"[{recorded_stack_counts[2*t+1]}] " - s += f"{op} {val}" + s += f"[{recorded_stack_counts[(1 + nb_digits)*t+1]}] " return s ###################################################################### if __name__ == "__main__": - nb, nb_steps, nb_stacks, nb_values = 150000, 10, 1, 5 + nb, nb_steps, nb_stacks, nb_digits = 150000, 10, 1, 1 seq, recorded_stack_counts = generate_sequences( - nb=nb, nb_steps=nb_steps, nb_stacks=nb_stacks, nb_values=nb_values + nb=nb, + nb_steps=nb_steps, + nb_stacks=nb_stacks, + nb_digits=nb_digits, ) print("-- TRAIN -----------------------------") for n in range(min(10, seq.size(0))): # print(seq_to_str(seq[n], recorded_stack_counts[n])) - print(seq_to_str(seq[n],show_stack_nb=nb_stacks>1)) + print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits)) print("-- TEST ------------------------------") - remove_poped_values(seq, nb_stacks) + remove_popped_values(seq, nb_stacks, nb_digits) for n in range(min(10, seq.size(0))): - print(seq_to_str(seq[n],show_stack_nb=nb_stacks>1)) + print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits)) -- 2.20.1