From 1eef58fd084437bbcd2041b946b468615e203dd8 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 19 Feb 2024 14:56:02 +0100 Subject: [PATCH] Update. --- main.py | 11 +++++++---- tasks.py | 52 ++++++++++++++++++++++++++++++++-------------------- 2 files changed, 39 insertions(+), 24 deletions(-) diff --git a/main.py b/main.py index 55f2c2f..9198edc 100755 --- a/main.py +++ b/main.py @@ -89,7 +89,9 @@ parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth") ############################## # filetask -parser.add_argument("--filetask_file", type=str, default=None) +parser.add_argument("--filetask_train_file", type=str, default=None) + +parser.add_argument("--filetask_test_file", type=str, default=None) ############################## # rpl options @@ -403,10 +405,11 @@ picoclvr_pruner_eval = ( if args.task == "file": assert ( - args.filetask_file is not None - ), "You have to specify the task file with --filetask_file " + args.filetask_train_file is not None and args.filetask_test_file is not None + ), "You have to specify the task train and test files" task = tasks.TaskFromFile( - args.filetask_file, + args.filetask_train_file, + args.filetask_test_file, nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, diff --git a/tasks.py b/tasks.py index 1ea3b5d..e5d3a7e 100755 --- a/tasks.py +++ b/tasks.py @@ -117,7 +117,8 @@ class TaskFromFile(Task): def __init__( self, - filename, + train_filename, + test_filename, nb_train_samples, nb_test_samples, batch_size, @@ -126,26 +127,37 @@ class TaskFromFile(Task): self.batch_size = batch_size self.device = device - pairs = [] - with open(filename, "r") as f: - for _ in range(nb_train_samples + nb_test_samples): - sequence = f.readline().strip() - pred_mask = f.readline().strip() - assert len(sequence) == len(pred_mask) - assert set(pred_mask).issubset({"0", "1", "2"}), f"{set(pred_mask)}" - pairs.append((sequence, pred_mask)) - - symbols = ["#"] + list(set("".join([x[0] for x in pairs])) - set(["#"])) + def read_file(filename, nb=-1): + pairs = [] + with open(filename, "r") as f: + while True: + sequence = f.readline().strip() + if not sequence: + break + pred_mask = f.readline().strip() + assert len(sequence) == len(pred_mask) + assert set(pred_mask).issubset({"0", "1", "2"}), f"{set(pred_mask)}" + pairs.append((sequence, pred_mask)) + if len(pairs) == nb: + break + + if nb > 0: + pairs = pairs[:nb] + assert len(pairs) == nb + + return pairs + + train_pairs = read_file(train_filename, nb_train_samples) + test_pairs = read_file(test_filename, nb_test_samples) + + symbols = ["#"] + list( + set("".join([x[0] for x in train_pairs + test_pairs])) - set(["#"]) + ) self.char2id = dict([(c, n) for n, c in enumerate(symbols)]) self.id2char = dict([(n, c) for c, n in self.char2id.items()]) - self.train_input, self.train_pred_masks = self.tensorize( - pairs[:nb_train_samples] - ) - self.test_input, self.test_pred_masks = self.tensorize(pairs[nb_train_samples:]) - - assert self.train_input.size(0) == nb_train_samples - assert self.test_input.size(0) == nb_test_samples + self.train_input, self.train_pred_masks = self.tensorize(train_pairs) + self.test_input, self.test_pred_masks = self.tensorize(test_pairs) def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} @@ -176,7 +188,7 @@ class TaskFromFile(Task): logger(f"----------------------------------------------------------") - for e in self.tensor2str(result[:10]): + for e in self.tensor2str(result[:50]): logger(f"test_before {e}") masked_inplace_autoregression( @@ -190,7 +202,7 @@ class TaskFromFile(Task): logger(f"----------------------------------------------------------") - for e, c in zip(self.tensor2str(result[:10]), self.tensor2str(correct[:10])): + for e, c in zip(self.tensor2str(result[:50]), self.tensor2str(correct[:50])): logger(f"test_after {e}") logger(f"correct {c}") -- 2.20.1