From 8ea809c43242d3a2e063692105919a86c3f6fe6b Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 19 Feb 2024 16:47:27 +0100 Subject: [PATCH] Update. --- main.py | 1 + tasks.py | 17 ++++++++++++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 9198edc..958dfc7 100755 --- a/main.py +++ b/main.py @@ -413,6 +413,7 @@ if args.task == "file": nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, + shuffle=True, device=device, ) args.max_percents_of_test_in_train = 0 diff --git a/tasks.py b/tasks.py index e5d3a7e..d21e264 100755 --- a/tasks.py +++ b/tasks.py @@ -71,7 +71,7 @@ class Task: class TaskFromFile(Task): - def tensorize(self, pairs): + def tensorize(self, pairs, shuffle): len_max = max([len(x[0]) for x in pairs]) input = torch.cat( @@ -98,6 +98,12 @@ class TaskFromFile(Task): 0, ).to("cpu") + if shuffle: + print("SHUFFLING!") + i = torch.randperm(input.size(0)) + input = input[i].contiguous() + pred_mask = pred_mask[i].contiguous() + return input, pred_mask # trim all the tensors in the tuple z to remove as much token from @@ -122,6 +128,7 @@ class TaskFromFile(Task): nb_train_samples, nb_test_samples, batch_size, + shuffle=False, device=torch.device("cpu"), ): self.batch_size = batch_size @@ -156,8 +163,12 @@ class TaskFromFile(Task): self.char2id = dict([(c, n) for n, c in enumerate(symbols)]) self.id2char = dict([(n, c) for c, n in self.char2id.items()]) - self.train_input, self.train_pred_masks = self.tensorize(train_pairs) - self.test_input, self.test_pred_masks = self.tensorize(test_pairs) + self.train_input, self.train_pred_masks = self.tensorize( + train_pairs, shuffle=shuffle + ) + self.test_input, self.test_pred_masks = self.tensorize( + test_pairs, shuffle=shuffle + ) def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} -- 2.20.1