From: François Fleuret Date: Mon, 10 Jul 2023 07:10:36 +0000 (+0200) Subject: Update. X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=994d2408781ebaed6da16b10b2b3ebedeff82756;p=picoclvr.git Update. --- diff --git a/expr.py b/expr.py index f294d68..847cd36 100755 --- a/expr.py +++ b/expr.py @@ -16,56 +16,46 @@ def random_var(nb_variables=None, variables=None): return l[torch.randint(len(l), (1,)).item()] -def random_expr(variables, budget): +def random_expr(variables, operand_max, budget): if budget <= 5: op = torch.randint(2, (1,)).item() if op == 0 and len(variables) > 0: return random_var(variables=variables) else: - return str(torch.randint(10, (1,)).item()) + return str(torch.randint(operand_max + 1, (1,)).item()) else: op = torch.randint(3, (1,)).item() if op == 0: - e = random_expr(variables, budget - 2) + e = random_expr(variables, operand_max, budget - 2) if ("+" in e or "-" in e or "*" in e) and (e[0] != "(" or e[-1] != ")"): return "(" + e + ")" else: return e else: b = 2 + torch.randint(budget - 5, (1,)).item() - e1 = random_expr(variables, b) - e2 = random_expr(variables, budget - b - 1) + e1 = random_expr(variables, operand_max, b) + e2 = random_expr(variables, operand_max, budget - b - 1) if op == 1: return e1 + "+" + e2 elif op == 2: return e1 + "*" + e2 -def generate_program(nb_variables, length): +def generate_program(nb_variables, operand_max, length): s = "" variables = set() while len(s) < length: v = random_var(nb_variables=nb_variables) - s += v + "=" + random_expr(variables, budget=20) + ";" + s += v + "=" + random_expr(variables, operand_max, budget=20) + ";" variables.add(v) return s, variables -def extract_results(seq): - f = lambda a: (a[0], -1 if a[1] == "" else int(a[1])) - results = [ - dict([f(tuple(x.split(":"))) for x in re.findall("[A-Z]:[0-9]*", s)]) - for s in seq - ] - return results - - -def generate_sequences(nb, nb_variables=5, length=20): +def generate_sequences(nb, nb_variables=5, length=20, operand_max=9, result_max=99): assert nb_variables <= 26 sequences = [] - result_max = 99 for n in range(nb): # We take length itself half of the time, and uniform between @@ -75,7 +65,7 @@ def generate_sequences(nb, nb_variables=5, length=20): l = min(length, 1 + torch.randint(length * 2, (1,)).item()) result = None while result == None or max(result.values()) > result_max: - p, v = generate_program(nb_variables, l) + p, v = generate_program(nb_variables, operand_max, l) v = ", ".join(['"' + v + '": ' + v for v in v]) ldict = {} exec(p + "result={" + v + "}", globals(), ldict) @@ -88,6 +78,15 @@ def generate_sequences(nb, nb_variables=5, length=20): return sequences +def extract_results(seq): + f = lambda a: (a[0], -1 if a[1] == "" else int(a[1])) + results = [ + dict([f(tuple(x.split(":"))) for x in re.findall("[A-Z]:[0-9]*", s)]) + for s in seq + ] + return results + + if __name__ == "__main__": import time diff --git a/main.py b/main.py index abed321..80f2733 100755 --- a/main.py +++ b/main.py @@ -127,6 +127,10 @@ parser.add_argument("--expr_nb_variables", type=int, default=5) parser.add_argument("--expr_sequence_length", type=int, default=40) +parser.add_argument("--expr_operand_max", type=int, default=9) + +parser.add_argument("--expr_result_max", type=int, default=99) + parser.add_argument("--expr_input_file", type=str, default=None) ###################################################################### @@ -307,6 +311,8 @@ elif args.task == "expr": nb_test_samples=args.nb_test_samples, nb_variables=args.expr_nb_variables, sequence_length=args.expr_sequence_length, + operand_max=args.expr_operand_max, + result_max=args.expr_result_max, batch_size=args.batch_size, device=device, ) diff --git a/tasks.py b/tasks.py index 4d7e90e..82d965b 100755 --- a/tasks.py +++ b/tasks.py @@ -30,10 +30,19 @@ def masked_inplace_autoregression( total=input.size(0) // batch_size, ) - for input, ar_mask in batches: - model.masked_inplace_autoregression( - input, ar_mask, forbidden_tokens, deterministic_synthesis - ) + with torch.autograd.no_grad(): + t = model.training + model.eval() + + for input, ar_mask in batches: + model.masked_inplace_autoregression( + input, ar_mask, forbidden_tokens, deterministic_synthesis + ) + + model.train(t) + + +###################################################################### class Task: @@ -447,70 +456,64 @@ class Maze(Task): def produce_results( self, n_epoch, model, result_dir, logger, deterministic_synthesis ): - with torch.autograd.no_grad(): - t = model.training - model.eval() - - train_nb_total, train_nb_correct, count = self.compute_error( - model, - "train", - nb_to_use=1000, - deterministic_synthesis=deterministic_synthesis, - ) - logger( - f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%" - ) - - test_nb_total, test_nb_correct, count = self.compute_error( - model, - "test", - nb_to_use=1000, - deterministic_synthesis=deterministic_synthesis, - ) - logger( - f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" - ) + train_nb_total, train_nb_correct, count = self.compute_error( + model, + "train", + nb_to_use=1000, + deterministic_synthesis=deterministic_synthesis, + ) + logger( + f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%" + ) - if count is not None: - proportion_optimal = count.diagonal().sum().float() / count.sum() - logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%") - with open( - os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w" - ) as f: - for i in range(count.size(0)): - for j in range(count.size(1)): - eol = " " if j < count.size(1) - 1 else "\n" - f.write(f"{count[i,j]}{eol}") - - input = self.test_input[:48] - result = input.clone() - ar_mask = result.new_zeros(result.size()) - ar_mask[:, self.height * self.width :] = 1 - result *= 1 - ar_mask - masked_inplace_autoregression( - model, - self.batch_size, - result, - ar_mask, - deterministic_synthesis, - device=self.device, - ) + test_nb_total, test_nb_correct, count = self.compute_error( + model, + "test", + nb_to_use=1000, + deterministic_synthesis=deterministic_synthesis, + ) + logger( + f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" + ) - mazes, paths = self.seq2map(input) - _, predicted_paths = self.seq2map(result) - - filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png") - maze.save_image( - filename, - mazes=mazes, - target_paths=paths, - predicted_paths=predicted_paths, - path_correct=maze.path_correctness(mazes, predicted_paths), - path_optimal=maze.path_optimality(paths, predicted_paths), - ) - logger(f"wrote {filename}") + if count is not None: + proportion_optimal = count.diagonal().sum().float() / count.sum() + logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%") + with open( + os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w" + ) as f: + for i in range(count.size(0)): + for j in range(count.size(1)): + eol = " " if j < count.size(1) - 1 else "\n" + f.write(f"{count[i,j]}{eol}") + + input = self.test_input[:48] + result = input.clone() + ar_mask = result.new_zeros(result.size()) + ar_mask[:, self.height * self.width :] = 1 + result *= 1 - ar_mask + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) - model.train(t) + mazes, paths = self.seq2map(input) + _, predicted_paths = self.seq2map(result) + + filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png") + maze.save_image( + filename, + mazes=mazes, + target_paths=paths, + predicted_paths=predicted_paths, + path_correct=maze.path_correctness(mazes, predicted_paths), + path_optimal=maze.path_optimality(paths, predicted_paths), + ) + logger(f"wrote {filename}") ###################################################################### @@ -577,59 +580,51 @@ class Snake(Task): def produce_results( self, n_epoch, model, result_dir, logger, deterministic_synthesis ): - with torch.autograd.no_grad(): - t = model.training - model.eval() - - def compute_nb_correct(input, prior_visits): - result = input.clone() - i = torch.arange(result.size(1), device=result.device)[None, :] - ar_mask = ( - torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0) - .long() - .expand_as(result) - ) - result *= 1 - ar_mask - - # snake.solver(result,ar_mask) + def compute_nb_correct(input, prior_visits): + result = input.clone() + i = torch.arange(result.size(1), device=result.device)[None, :] + ar_mask = ( + torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0) + .long() + .expand_as(result) + ) + result *= 1 - ar_mask - masked_inplace_autoregression( - model, - self.batch_size, - result, - ar_mask, - deterministic_synthesis, - device=self.device, - ) + # snake.solver(result,ar_mask) - nb_total = ((prior_visits > 0) * ar_mask).sum() + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) - nb_correct = ( - (result == input).long() * (prior_visits > 0) * ar_mask - ).sum() + nb_total = ((prior_visits > 0) * ar_mask).sum() - # nb_total = result.size(0) - # nb_correct = ((result - input).abs().sum(1) == 0).sum() + nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum() - return nb_total, nb_correct + # nb_total = result.size(0) + # nb_correct = ((result - input).abs().sum(1) == 0).sum() - # train_nb_total, train_nb_correct = compute_nb_correct( - # self.train_input, self.train_prior_visits - # ) + return nb_total, nb_correct - # logger( - # f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%" - # ) + # train_nb_total, train_nb_correct = compute_nb_correct( + # self.train_input, self.train_prior_visits + # ) - test_nb_total, test_nb_correct = compute_nb_correct( - self.test_input[:1000], self.test_prior_visits[:1000] - ) + # logger( + # f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%" + # ) - logger( - f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" - ) + test_nb_total, test_nb_correct = compute_nb_correct( + self.test_input[:1000], self.test_prior_visits[:1000] + ) - model.train(t) + logger( + f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" + ) ###################################################################### @@ -709,51 +704,10 @@ class Stack(Task): def produce_results( self, n_epoch, model, result_dir, logger, deterministic_synthesis ): - with torch.autograd.no_grad(): - t = model.training - model.eval() - - def compute_nb_correct(input): - result = input.clone() - stack.remove_popped_values(result, self.nb_stacks, self.nb_digits) - ar_mask = (result != input).long() - masked_inplace_autoregression( - model, - self.batch_size, - result, - ar_mask, - deterministic_synthesis, - device=self.device, - ) - - errors = ((result != input).long() * ar_mask).reshape( - -1, 1 + self.nb_digits - ) - ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits) - - nb_total = ar_mask.max(1).values.sum() - nb_correct = nb_total - errors.max(1).values.sum() - - return nb_total, nb_correct - - test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000]) - - logger( - f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" - ) - - ############################################################## - # Log a few generated sequences - input = self.test_input[:10, : 12 * (1 + self.nb_digits)] + def compute_nb_correct(input): result = input.clone() stack.remove_popped_values(result, self.nb_stacks, self.nb_digits) ar_mask = (result != input).long() - - # for n in range(result.size(0)): - # logger( - # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}" - # ) - masked_inplace_autoregression( model, self.batch_size, @@ -763,13 +717,48 @@ class Stack(Task): device=self.device, ) - for n in range(result.size(0)): - logger( - f"test_after {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}" - ) - ############################################################## + errors = ((result != input).long() * ar_mask).reshape( + -1, 1 + self.nb_digits + ) + ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits) + + nb_total = ar_mask.max(1).values.sum() + nb_correct = nb_total - errors.max(1).values.sum() + + return nb_total, nb_correct + + test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000]) + + logger( + f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" + ) + + ############################################################## + # Log a few generated sequences + input = self.test_input[:10, : 12 * (1 + self.nb_digits)] + result = input.clone() + stack.remove_popped_values(result, self.nb_stacks, self.nb_digits) + ar_mask = (result != input).long() + + # for n in range(result.size(0)): + # logger( + # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}" + # ) + + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) - model.train(t) + for n in range(result.size(0)): + logger( + f"test_after {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}" + ) + ############################################################## ###################################################################### @@ -799,6 +788,8 @@ class Expr(Task): nb_test_samples, nb_variables, sequence_length, + operand_max, + result_max, batch_size, device=torch.device("cpu"), ): @@ -809,12 +800,16 @@ class Expr(Task): nb_train_samples, nb_variables=nb_variables, length=sequence_length, + operand_max=operand_max, + result_max=result_max, ) test_sequences = expr.generate_sequences( nb_test_samples, nb_variables=nb_variables, length=sequence_length, + operand_max=operand_max, + result_max=result_max, ) symbols = list(set("#" + "".join(train_sequences + test_sequences))) @@ -859,107 +854,101 @@ class Expr(Task): deterministic_synthesis, input_file=None, ): - with torch.autograd.no_grad(): - t = model.training - model.eval() - - def compute_nb_correct(input): - result = input.clone() - s = (result == self.space).long() - ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1) - result = (1 - ar_mask) * result + ar_mask * self.filler - masked_inplace_autoregression( - model, - self.batch_size, - result, - ar_mask, - deterministic_synthesis, - device=self.device, - ) + def compute_nb_correct(input): + result = input.clone() + s = (result == self.space).long() + ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1) + result = (1 - ar_mask) * result + ar_mask * self.filler + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) - nb_total = input.size(0) - nb_correct = (input == result).long().min(1).values.sum() + nb_total = input.size(0) + nb_correct = (input == result).long().min(1).values.sum() - ####################################################################### - # Comput predicted vs. true variable values + ####################################################################### + # Comput predicted vs. true variable values - nb_delta = torch.zeros(5, dtype=torch.int64) - nb_missed = 0 + nb_delta = torch.zeros(5, dtype=torch.int64) + nb_missed = 0 - values_input = expr.extract_results([self.seq2str(s) for s in input]) - values_result = expr.extract_results([self.seq2str(s) for s in result]) + values_input = expr.extract_results([self.seq2str(s) for s in input]) + values_result = expr.extract_results([self.seq2str(s) for s in result]) - for i, r in zip(values_input, values_result): - for n, vi in i.items(): - vr = r.get(n) - if vr is None or vr < 0: + for i, r in zip(values_input, values_result): + for n, vi in i.items(): + vr = r.get(n) + if vr is None or vr < 0: + nb_missed += 1 + else: + d = abs(vr - vi) + if d >= nb_delta.size(0): nb_missed += 1 else: - d = abs(vr - vi) - if d >= nb_delta.size(0): - nb_missed += 1 - else: - nb_delta[d] += 1 + nb_delta[d] += 1 - ###################################################################### + ###################################################################### - return nb_total, nb_correct, nb_delta, nb_missed + return nb_total, nb_correct, nb_delta, nb_missed - ( - test_nb_total, - test_nb_correct, - test_nb_delta, - test_nb_missed, - ) = compute_nb_correct(self.test_input[:10000]) + ( + test_nb_total, + test_nb_correct, + test_nb_delta, + test_nb_missed, + ) = compute_nb_correct(self.test_input[:10000]) - logger( - f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" - ) + logger( + f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" + ) - nb_total = test_nb_delta.sum() + test_nb_missed - for d in range(test_nb_delta.size(0)): - logger( - f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%" - ) + nb_total = test_nb_delta.sum() + test_nb_missed + for d in range(test_nb_delta.size(0)): logger( - f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%" + f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%" ) + logger( + f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%" + ) - ############################################################## - # Log a few generated sequences - if input_file is None: - input = self.test_input[:10] - else: - with open(input_file, "r") as f: - sequences = [e.strip() for e in f.readlines()] - sequences = [s + " " + "#" * 50 for s in sequences] - input = self.tensorize(sequences) - - result = input.clone() - s = (result == self.space).long() - ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1) - result = (1 - ar_mask) * result + ar_mask * self.filler + ############################################################## + # Log a few generated sequences + if input_file is None: + input = self.test_input[:10] + else: + with open(input_file, "r") as f: + sequences = [e.strip() for e in f.readlines()] + sequences = [s + " " + "#" * 50 for s in sequences] + input = self.tensorize(sequences) - for n in range(result.size(0)): - logger(f"test_before {self.seq2str(result[n])}") + result = input.clone() + s = (result == self.space).long() + ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1) + result = (1 - ar_mask) * result + ar_mask * self.filler - masked_inplace_autoregression( - model, - self.batch_size, - result, - ar_mask, - deterministic_synthesis, - device=self.device, - ) + for n in range(result.size(0)): + logger(f"test_before {self.seq2str(result[n])}") - correct = (1 - ar_mask) * self.space + ar_mask * input - for n in range(result.size(0)): - comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else "" - logger(f"test_after {self.seq2str(result[n])} {comment}") - logger(f"truth {self.seq2str(correct[n])}") - ############################################################## + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) - model.train(t) + correct = (1 - ar_mask) * self.space + ar_mask * input + for n in range(result.size(0)): + comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else "" + logger(f"test_after {self.seq2str(result[n])} {comment}") + logger(f"truth {self.seq2str(correct[n])}") + ############################################################## ######################################################################