From c921b95d0ea5b94a893447fbd4792e5047ba6e99 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 19 Jun 2023 18:10:36 +0200 Subject: [PATCH] Update. --- main.py | 232 +++++++++++++++++++++++++++++++++++++----- maze.py | 311 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 520 insertions(+), 23 deletions(-) create mode 100755 maze.py diff --git a/main.py b/main.py index 08afb66..ae42544 100755 --- a/main.py +++ b/main.py @@ -30,6 +30,8 @@ parser = argparse.ArgumentParser( description="An implementation of GPT with cache to solve a toy geometric reasoning task." ) +parser.add_argument("--task", type=str, default="picoclvr") + parser.add_argument("--log_filename", type=str, default="train.log") parser.add_argument("--result_dir", type=str, default="results_default") @@ -73,19 +75,28 @@ parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth") ############################## # picoclvr options -parser.add_argument("--nb_colors", type=int, default=5) +parser.add_argument("--picoclvr_nb_colors", type=int, default=5) + +parser.add_argument("--picoclvr_height", type=int, default=12) + +parser.add_argument("--picoclvr_width", type=int, default=16) + +parser.add_argument("--picocvlr_prune_properties", type=str, default="none") + +############################## +# Maze options -parser.add_argument("--height", type=int, default=12) +parser.add_argument("--maze_height", type=int, default=13) -parser.add_argument("--width", type=int, default=16) +parser.add_argument("--maze_width", type=int, default=21) -parser.add_argument("--prune_properties", type=str, default="none") +parser.add_argument("--maze_nb_walls", type=int, default=15) ###################################################################### args = parser.parse_args() -assert args.prune_properties in {"none", "train+eval", "eval"} +assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"} try: os.mkdir(args.result_dir) @@ -311,8 +322,12 @@ class TaskPicoCLVR(Task): "rng_state": list(torch.get_rng_state()), } - log_string(f"generating {nb_train_samples+nb_test_samples} samples (can take some time)") - self.train_descr = generate_descr(nb_train_samples, "train", pruner=self.pruner_train) + log_string( + f"generating {nb_train_samples+nb_test_samples} samples (can take some time)" + ) + self.train_descr = generate_descr( + nb_train_samples, "train", pruner=self.pruner_train + ) self.test_descr = generate_descr(nb_test_samples, "test", pruner=None) # Build the tokenizer @@ -445,29 +460,200 @@ class TaskPicoCLVR(Task): ###################################################################### -log_string(f"device {device}") +import maze + +class TaskMaze(Task): + def map2seq(self, *m): + return torch.cat([x.flatten(1) for x in m], 1) -def pruner_horizontal_green(p): + def seq2map(self, s): + s = s.reshape(s.size(0), -1, self.height, self.width) + return (s[:, k] for k in range(s.size(1))) + + def __init__( + self, + nb_train_samples, + nb_test_samples, + batch_size, + height, + width, + nb_walls, + device=torch.device("cpu"), + ): + self.batch_size = batch_size + self.height = height + self.width = width + self.device = device + + train_mazes, train_paths, train_policies = maze.create_maze_data( + nb_train_samples, + height=height, + width=width, + nb_walls=nb_walls, + progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"), + ) + self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device)) + self.train_policies = train_policies.flatten(-2).to(device) + + test_mazes, test_paths, test_policies = maze.create_maze_data( + nb_test_samples, + height=height, + width=width, + nb_walls=nb_walls, + progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"), + ) + self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device)) + self.test_policies = test_policies.flatten(-2).to(device) + + self.nb_codes = self.train_input.max() + 1 + + def batches(self, split="train", nb_to_use=-1, desc=None): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + if nb_to_use > 0: + input = input[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=desc + ): + yield batch + + def policy_batches(self, split="train", nb_to_use=-1, desc=None): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + policies = self.train_policies if split == "train" else self.test_policies + input = input[:, : self.height * self.width] + policies = policies * (input != maze.v_wall)[:, None] + + if nb_to_use > 0: + input = input[:nb_to_use] + policies = policies[:nb_to_use] + + if desc is None: + desc = f"epoch-{split}" + for batch in tqdm.tqdm( + zip(input.split(self.batch_size), policies.split(self.batch_size)), + dynamic_ncols=True, + desc=desc, + ): + yield batch + + def vocabulary_size(self): + return self.nb_codes + + def compute_error(self, model, split="train", nb_to_use=-1): + nb_total, nb_correct = 0, 0 + for input in task.batches(split, nb_to_use): + result = input.clone() + ar_mask = result.new_zeros(result.size()) + ar_mask[:, self.height * self.width :] = 1 + result *= 1 - ar_mask + masked_inplace_autoregression( + model, self.batch_size, result, ar_mask, device=self.device + ) + mazes, paths = self.seq2map(result) + nb_correct += maze.path_correctness(mazes, paths).long().sum() + nb_total += mazes.size(0) + + return nb_total, nb_correct + + def produce_results(self, n_epoch, model): + with torch.autograd.no_grad(): + t = model.training + model.eval() + + train_nb_total, train_nb_correct = self.compute_error( + model, "train", nb_to_use=1000 + ) + log_string( + f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%" + ) + + test_nb_total, test_nb_correct = self.compute_error( + model, "test", nb_to_use=1000 + ) + log_string( + f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" + ) + + input = self.test_input[:48] + result = input.clone() + ar_mask = result.new_zeros(result.size()) + ar_mask[:, self.height * self.width :] = 1 + result *= 1 - ar_mask + masked_inplace_autoregression( + model, self.batch_size, result, ar_mask, device=self.device + ) + + mazes, paths = self.seq2map(input) + _, predicted_paths = self.seq2map(result) + filename = f"result_{n_epoch:04d}.png" + maze.save_image( + os.path.join(args.result_dir, filename), + mazes=mazes, + target_paths=paths, + predicted_paths=predicted_paths, + path_correct=maze.path_correctness(mazes, predicted_paths), + ) + log_string(f"wrote {filename}") + + model.train(t) + + +###################################################################### + + +def picoclvr_pruner_horizontal_green(p): return not ("green" in p and ("left" in p or "right" in p)) -task = TaskPicoCLVR( - nb_train_samples=args.nb_train_samples, - nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, - height=args.height, - width=args.width, - nb_colors=args.nb_colors, - device=device, - pruner_train=pruner_horizontal_green - if args.prune_properties in {"train+eval"} - else None, - pruner_eval=(lambda p: not pruner_horizontal_green(p)) - if args.prune_properties in {"train+eval", "eval"} - else None, +picoclvr_pruner_train = ( + picoclvr_pruner_horizontal_green + if args.picocvlr_prune_properties in {"train+eval"} + else None +) + +picoclvr_pruner_eval = ( + (lambda p: not picoclvr_pruner_horizontal_green(p)) + if args.picocvlr_prune_properties in {"train+eval", "eval"} + else None ) +###################################################################### + +if args.task == "picoclvr": + task = TaskPicoCLVR( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + height=args.picoclvr_height, + width=args.picoclvr_width, + nb_colors=args.picoclvr_nb_colors, + device=device, + pruner_train=picoclvr_pruner_train, + pruner_eval=picoclvr_pruner_eval, + ) + +elif args.task == "maze": + task = TaskMaze( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + height=args.maze_height, + width=args.maze_width, + nb_walls=args.maze_nb_walls, + device=device, + ) + +else: + raise ValueError(f"Unknown task {args.task}") + +###################################################################### + +log_string(f"device {device}") + vocabulary_size = task.vocabulary_size() log_string(f"vocabulary_size {vocabulary_size}") diff --git a/maze.py b/maze.py new file mode 100755 index 0000000..81afcd9 --- /dev/null +++ b/maze.py @@ -0,0 +1,311 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import torch, torchvision + +###################################################################### + +v_empty, v_wall, v_start, v_goal, v_path = 0, 1, 2, 3, 4 + + +def create_maze(h=11, w=17, nb_walls=8): + a, k = 0, 0 + + while k < nb_walls: + while True: + if a == 0: + m = torch.zeros(h, w, dtype=torch.int64) + m[0, :] = 1 + m[-1, :] = 1 + m[:, 0] = 1 + m[:, -1] = 1 + + r = torch.rand(4) + + if r[0] <= 0.5: + i1, i2, j = ( + int((r[1] * h).item()), + int((r[2] * h).item()), + int((r[3] * w).item()), + ) + i1, i2, j = i1 - i1 % 2, i2 - i2 % 2, j - j % 2 + i1, i2 = min(i1, i2), max(i1, i2) + if i2 - i1 > 1 and i2 - i1 <= h / 2 and m[i1 : i2 + 1, j].sum() <= 1: + m[i1 : i2 + 1, j] = 1 + break + else: + i, j1, j2 = ( + int((r[1] * h).item()), + int((r[2] * w).item()), + int((r[3] * w).item()), + ) + i, j1, j2 = i - i % 2, j1 - j1 % 2, j2 - j2 % 2 + j1, j2 = min(j1, j2), max(j1, j2) + if j2 - j1 > 1 and j2 - j1 <= w / 2 and m[i, j1 : j2 + 1].sum() <= 1: + m[i, j1 : j2 + 1] = 1 + break + a += 1 + + if a > 10 * nb_walls: + a, k = 0, 0 + + k += 1 + + return m + + +###################################################################### + + +def compute_distance(walls, goal_i, goal_j): + max_length = walls.numel() + dist = torch.full_like(walls, max_length) + + dist[goal_i, goal_j] = 0 + pred_dist = torch.empty_like(dist) + + while True: + pred_dist.copy_(dist) + d = ( + torch.cat( + ( + dist[None, 1:-1, 0:-2], + dist[None, 2:, 1:-1], + dist[None, 1:-1, 2:], + dist[None, 0:-2, 1:-1], + ), + 0, + ).min(dim=0)[0] + + 1 + ) + + dist[1:-1, 1:-1] = torch.min(dist[1:-1, 1:-1], d) + dist = walls * max_length + (1 - walls) * dist + + if dist.equal(pred_dist): + return dist * (1 - walls) + + +###################################################################### + + +def compute_policy(walls, goal_i, goal_j): + distance = compute_distance(walls, goal_i, goal_j) + distance = distance + walls.numel() * walls + + value = distance.new_full((4,) + distance.size(), walls.numel()) + value[0, :, 1:] = distance[:, :-1] # < + value[1, :, :-1] = distance[:, 1:] # > + value[2, 1:, :] = distance[:-1, :] # ^ + value[3, :-1, :] = distance[1:, :] # v + + proba = (value.min(dim=0)[0][None] == value).float() + proba = proba / proba.sum(dim=0)[None] + proba = proba * (1 - walls) + walls.float() / 4 + + return proba + + +def stationary_densities(mazes, policies): + policies = policies * (mazes != v_goal)[:, None] + start = (mazes == v_start).nonzero(as_tuple=True) + probas = mazes.new_zeros(mazes.size(), dtype=torch.float32) + pred_probas = probas.clone() + probas[start] = 1.0 + + while not pred_probas.equal(probas): + pred_probas.copy_(probas) + probas.zero_() + probas[:, 1:, :] += pred_probas[:, :-1, :] * policies[:, 3, :-1, :] + probas[:, :-1, :] += pred_probas[:, 1:, :] * policies[:, 2, 1:, :] + probas[:, :, 1:] += pred_probas[:, :, :-1] * policies[:, 1, :, :-1] + probas[:, :, :-1] += pred_probas[:, :, 1:] * policies[:, 0, :, 1:] + probas[start] = 1.0 + + return probas + + +###################################################################### + + +def mark_path(walls, i, j, goal_i, goal_j, policy): + action = torch.distributions.categorical.Categorical( + policy.permute(1, 2, 0) + ).sample() + n, nmax = 0, walls.numel() + while i != goal_i or j != goal_j: + di, dj = [(0, -1), (0, 1), (-1, 0), (1, 0)][action[i, j]] + i, j = i + di, j + dj + assert walls[i, j] == 0 + walls[i, j] = v_path + n += 1 + assert n < nmax + + +def path_correctness(mazes, paths): + still_ok = (mazes - (paths * (paths < 4))).view(mazes.size(0), -1).abs().sum(1) == 0 + reached = still_ok.new_zeros(still_ok.size()) + current, pred_current = paths.clone(), paths.new_zeros(paths.size()) + goal = (mazes == v_goal).long() + while not pred_current.equal(current): + pred_current.copy_(current) + u = (current == v_start).long() + possible_next = ( + u[:, 2:, 1:-1] + u[:, 0:-2, 1:-1] + u[:, 1:-1, 2:] + u[:, 1:-1, 0:-2] > 0 + ).long() + u = u[:, 1:-1, 1:-1] + reached += ((goal[:, 1:-1, 1:-1] * possible_next).sum((1, 2)) == 1) * ( + (current == v_path).sum((1, 2)) == 0 + ) + current[:, 1:-1, 1:-1] = (1 - u) * current[:, 1:-1, 1:-1] + ( + v_start - v_path + ) * (possible_next * (current[:, 1:-1, 1:-1] == v_path)) + still_ok *= (current == v_start).sum((1, 2)) <= 1 + + return still_ok * reached + + +###################################################################### + + +def create_maze_data( + nb, height=11, width=17, nb_walls=8, dist_min=10, progress_bar=lambda x: x +): + mazes = torch.empty(nb, height, width, dtype=torch.int64) + paths = torch.empty(nb, height, width, dtype=torch.int64) + policies = torch.empty(nb, 4, height, width) + + for n in progress_bar(range(nb)): + maze = create_maze(height, width, nb_walls) + i = (maze == v_empty).nonzero() + while True: + start, goal = i[torch.randperm(i.size(0))[:2]] + if (start - goal).abs().sum() >= dist_min: + break + start_i, start_j, goal_i, goal_j = start[0], start[1], goal[0], goal[1] + + policy = compute_policy(maze, goal_i, goal_j) + path = maze.clone() + mark_path(path, start_i, start_j, goal_i, goal_j, policy) + maze[start_i, start_j] = v_start + maze[goal_i, goal_j] = v_goal + path[start_i, start_j] = v_start + path[goal_i, goal_j] = v_goal + + mazes[n] = maze + paths[n] = path + policies[n] = policy + + return mazes, paths, policies + + +###################################################################### + + +def save_image( + name, + mazes, + target_paths=None, + predicted_paths=None, + score_paths=None, + score_truth=None, + path_correct=None, +): + colors = torch.tensor( + [ + [255, 255, 255], # empty + [0, 0, 0], # wall + [0, 255, 0], # start + [127, 127, 255], # goal + [255, 0, 0], # path + ] + ) + + mazes = mazes.cpu() + + c_mazes = ( + colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2) + ) + + if score_truth is not None: + score_truth = score_truth.cpu() + c_score_truth = score_truth.unsqueeze(1).expand(-1, 3, -1, -1) + c_score_truth = ( + c_score_truth * colors[4].reshape(1, 3, 1, 1) + + (1 - c_score_truth) * colors[0].reshape(1, 3, 1, 1) + ).long() + c_mazes = (mazes.unsqueeze(1) != v_empty) * c_mazes + ( + mazes.unsqueeze(1) == v_empty + ) * c_score_truth + + imgs = c_mazes.unsqueeze(1) + + if target_paths is not None: + target_paths = target_paths.cpu() + + c_target_paths = ( + colors[target_paths.reshape(-1)] + .reshape(target_paths.size() + (-1,)) + .permute(0, 3, 1, 2) + ) + + imgs = torch.cat((imgs, c_target_paths.unsqueeze(1)), 1) + + if predicted_paths is not None: + predicted_paths = predicted_paths.cpu() + c_predicted_paths = ( + colors[predicted_paths.reshape(-1)] + .reshape(predicted_paths.size() + (-1,)) + .permute(0, 3, 1, 2) + ) + imgs = torch.cat((imgs, c_predicted_paths.unsqueeze(1)), 1) + + if score_paths is not None: + score_paths = score_paths.cpu() + c_score_paths = score_paths.unsqueeze(1).expand(-1, 3, -1, -1) + c_score_paths = ( + c_score_paths * colors[4].reshape(1, 3, 1, 1) + + (1 - c_score_paths) * colors[0].reshape(1, 3, 1, 1) + ).long() + c_score_paths = c_score_paths * (mazes.unsqueeze(1) == v_empty) + c_mazes * ( + mazes.unsqueeze(1) != v_empty + ) + imgs = torch.cat((imgs, c_score_paths.unsqueeze(1)), 1) + + # NxKxCxHxW + if path_correct is None: + path_correct = torch.zeros(imgs.size(0)) <= 1 + path_correct = path_correct.cpu().long().view(-1, 1, 1, 1) + img = torch.tensor([224, 224, 224]).view(1, -1, 1, 1) * path_correct + torch.tensor( + [255, 0, 0] + ).view(1, -1, 1, 1) * (1 - path_correct) + img = img.expand( + -1, -1, imgs.size(3) + 2, 1 + imgs.size(1) * (1 + imgs.size(4)) + ).clone() + for k in range(imgs.size(1)): + img[ + :, + :, + 1 : 1 + imgs.size(3), + 1 + k * (1 + imgs.size(4)) : 1 + k * (1 + imgs.size(4)) + imgs.size(4), + ] = imgs[:, k] + + img = img.float() / 255.0 + + torchvision.utils.save_image(img, name, nrow=4, padding=1, pad_value=224.0 / 256) + + +###################################################################### + +if __name__ == "__main__": + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + mazes, paths = create_maze_data(8) + mazes, paths = mazes.to(device), paths.to(device) + save_image("test.png", mazes, paths, paths) + print(path_correctness(mazes, paths)) + +###################################################################### -- 2.20.1