From 0f4c86c0e7730db4147f136df5aeb5528fc943a0 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 21 Oct 2023 09:42:48 +0200 Subject: [PATCH] Update. --- README.txt | 5 +++ main.py | 23 ++++++++++++- problems.py | 97 ++++++++++++++++++++++++++++++++++++++++++++++++++--- tasks.py | 11 +++--- 4 files changed, 127 insertions(+), 9 deletions(-) diff --git a/README.txt b/README.txt index d4740f3..489a792 100644 --- a/README.txt +++ b/README.txt @@ -1,3 +1,8 @@ +18.10.2023 + +./main.py --task=qmlp --model=352M --nb_train_samples=250000 --result_dir=results_qmlp_352M --batch_size=2 + +~11h per epoch on 3090 Ti ====================================================================== For the stack experiment: diff --git a/main.py b/main.py index d961301..4a46fe6 100755 --- a/main.py +++ b/main.py @@ -33,7 +33,7 @@ parser.add_argument( "--task", type=str, default="twotargets", - help="byheart, learnop, guessop, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp", + help="byheart, learnop, guessop, twocuts, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp", ) parser.add_argument("--log_filename", type=str, default="train.log", help=" ") @@ -159,6 +159,11 @@ parser.add_argument("--expr_result_max", type=int, default=99) parser.add_argument("--expr_input_file", type=str, default=None) +############################## +# Misc + +parser.add_argument("--twocuts_no_global", action="store_true", default=False) + ###################################################################### args = parser.parse_args() @@ -249,6 +254,12 @@ default_task_args = { "nb_train_samples": 50000, "nb_test_samples": 10000, }, + "twocuts": { + "model": "37M", + "batch_size": 25, + "nb_train_samples": 100000, + "nb_test_samples": 10000, + }, "mnist": { "model": "37M", "batch_size": 10, @@ -403,6 +414,16 @@ elif args.task == "twotargets": device=device, ) +elif args.task == "twocuts": + task = tasks.SandBox( + problem=problems.ProblemTwoCuts(global_constraint = not args.twocuts_no_global), + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + logger=log_string, + device=device, + ) + elif args.task == "addition": task = tasks.SandBox( problem=problems.ProblemAddition(), diff --git a/problems.py b/problems.py index 2c8602c..68a46b3 100755 --- a/problems.py +++ b/problems.py @@ -17,6 +17,96 @@ class Problem: def seq2str(self, seq): return "[NOT IMPLEMENTED]" + def compute_nb_correct(self, input, ar_mask, result): + nb_total = ar_mask.sum().item() + nb_correct = ((result == input).long() * ar_mask).sum().item() + return nb_total, nb_correct + +#################### + + +class ProblemTwoCuts(Problem): + def __init__(self, len_total=50, nb_values=100, global_constraint=True): + self.len_total = len_total + self.nb_values = nb_values + self.global_constraint = global_constraint + + def generate_sequences_internal(self, nb): + return u,v,a,b,c + + def generate_sequences(self,nb): + + u = torch.randint(self.len_total, (nb,)) + v = torch.randint(self.len_total, (nb,)) + + a = torch.randint(self.nb_values, (nb,)) + b = torch.randint(self.nb_values, (nb,)) + c = torch.randint(self.nb_values, (nb,)) + + while True: + to_compute = torch.logical_or(u>=v-self.len_total//10,u=v).long().sum() == 0 + assert (a==b).long().sum() == 0 + assert (a==c).long().sum() == 0 + assert (c==b).long().sum() == 0 + + t = torch.arange(self.len_total) + seq = (t[None,:] < u[:,None]).long() * a[:,None] + \ + (t[None,:] >= u[:,None]).long() * (t[None,:] < v[:,None]).long() * b[:,None] + \ + (t[None,:] >= v[:,None]).long() * c[:,None] + + return seq,seq.new_full(seq.size(), 1, dtype=torch.int64) + + def compute_nb_correct(self, input, ar_mask, result): + nb_total = result.size(0) + nb_correct = 0 + i = torch.arange(result.size(1), device=result.device) + + for k in range(nb_total): + s = result[k] + a = s[0] + uu = (s != a).nonzero() + if uu.size(0) > 0: + u = uu.min() + b = s[u] + vv = torch.logical_and(s != b, i >= u).nonzero() + if vv.size(0) > 0: + v = vv.min() + c = s[v] + ww = torch.logical_and(s != c, i >= v).nonzero() + if ww.size(0) == 0: + if not self.global_constraint or (a*u+b*(v-u)+c*(self.len_total-v)) // self.len_total == self.nb_values//2: + nb_correct += 1 + + return nb_total, nb_correct + + def seq2str(self, seq): + return " ".join( [ f"{x:02d}" for x in seq ] ) #################### @@ -197,7 +287,6 @@ class ProblemAddition(Problem): if __name__ == "__main__": - p = ProblemTwoTargets(12, 4) - s, m = p.generate_sequences(10) - for x in s: - print(p.seq2str(x)) + p = ProblemTwoCuts(12) + s, m = p.generate_sequences(10000) + print(p.compute_nb_correct(None, None, s)) diff --git a/tasks.py b/tasks.py index b33dee2..0858282 100755 --- a/tasks.py +++ b/tasks.py @@ -110,13 +110,14 @@ class SandBox(Task): self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + # A bit of paranoia never hurts assert ( self.nb_codes <= max_nb_codes and self.train_input.min() >= 0 and self.test_input.min() >= 0 - and tuple(self.train_ar_mask.unique()) == (0, 1) - and tuple(self.test_ar_mask.unique()) == (0, 1) + and tuple(x.item() for x in self.train_ar_mask.unique()) in { (0,), (1,), (0,1) } + and tuple(x.item() for x in self.test_ar_mask.unique()) in { (0,), (1,), (0,1) } ) def batches(self, split="train", nb_to_use=-1, desc=None): @@ -160,8 +161,10 @@ class SandBox(Task): f" {n_epoch} ground truth {self.problem.seq2str(st)}" ) - nb_total = ar_mask.sum().item() - nb_correct = ((result == input).long() * ar_mask).sum().item() + nb_total, nb_correct = self.problem.compute_nb_correct(input, ar_mask, result) + + # nb_total = ar_mask.sum().item() + # nb_correct = ((result == input).long() * ar_mask).sum().item() return nb_total, nb_correct -- 2.20.1