From d8ec2ebf14b7299b246456a440ff15e97cfae472 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 22 Jun 2023 16:18:17 +0200 Subject: [PATCH] Update. --- main.py | 44 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/main.py b/main.py index db982ca..0144817 100755 --- a/main.py +++ b/main.py @@ -31,9 +31,11 @@ parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) -parser.add_argument("--task", type=str, default="picoclvr") +parser.add_argument( + "--task", type=str, default="picoclvr", help="picoclvr, mnist, maze, snake" +) -parser.add_argument("--log_filename", type=str, default="train.log") +parser.add_argument("--log_filename", type=str, default="train.log", help=" ") parser.add_argument("--result_dir", type=str, default="results_default") @@ -619,6 +621,9 @@ class TaskMaze(Task): def compute_error(self, model, split="train", nb_to_use=-1): nb_total, nb_correct = 0, 0 + count = torch.zeros( + self.width * self.height, self.width * self.height, device=self.device + ) for input in task.batches(split, nb_to_use): result = input.clone() ar_mask = result.new_zeros(result.size()) @@ -628,30 +633,57 @@ class TaskMaze(Task): model, self.batch_size, result, ar_mask, device=self.device ) mazes, paths = self.seq2map(result) - nb_correct += maze.path_correctness(mazes, paths).long().sum() + path_correctness = maze.path_correctness(mazes, paths) + nb_correct += path_correctness.long().sum() nb_total += mazes.size(0) - return nb_total, nb_correct + optimal_path_lengths = ( + (input[:, self.height * self.width :] == maze.v_path).long().sum(1) + ) + predicted_path_lengths = ( + (result[:, self.height * self.width :] == maze.v_path).long().sum(1) + ) + optimal_path_lengths = optimal_path_lengths[path_correctness] + predicted_path_lengths = predicted_path_lengths[path_correctness] + count[optimal_path_lengths, predicted_path_lengths] += 1 + + if count.max() == 0: + count = None + else: + count = count[ + : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1 + ] + + return nb_total, nb_correct, count def produce_results(self, n_epoch, model): with torch.autograd.no_grad(): t = model.training model.eval() - train_nb_total, train_nb_correct = self.compute_error( + train_nb_total, train_nb_correct, count = self.compute_error( model, "train", nb_to_use=1000 ) log_string( f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%" ) - test_nb_total, test_nb_correct = self.compute_error( + test_nb_total, test_nb_correct, count = self.compute_error( model, "test", nb_to_use=1000 ) log_string( f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" ) + if count is not None: + with open( + os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.txt"), "w" + ) as f: + for i in range(count.size(0)): + for j in range(count.size(1)): + eol = " " if j < count.size(1) - 1 else "\n" + f.write(f"{count[i,j]}{eol}") + input = self.test_input[:48] result = input.clone() ar_mask = result.new_zeros(result.size()) -- 2.20.1