Add the twocuts task and per-problem accuracy computation.
author     François Fleuret <francois@fleuret.org>
           Sat, 21 Oct 2023 07:42:48 +0000 (09:42 +0200)
committer  François Fleuret <francois@fleuret.org>
           Sat, 21 Oct 2023 07:42:48 +0000 (09:42 +0200)
README.txt
main.py
problems.py
tasks.py

diff --git a/README.txt b/README.txt
index d4740f3..489a792 100644 (file)
--- a/README.txt
+++ b/README.txt
@@ -1,3 +1,8 @@
+18.10.2023
+
+./main.py --task=qmlp --model=352M --nb_train_samples=250000 --result_dir=results_qmlp_352M --batch_size=2
+
+~11h per epoch on 3090 Ti
 
 ======================================================================
 For the stack experiment:
diff --git a/main.py b/main.py
index d961301..4a46fe6 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -33,7 +33,7 @@ parser.add_argument(
     "--task",
     type=str,
     default="twotargets",
-    help="byheart, learnop, guessop, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp",
+    help="byheart, learnop, guessop, twocuts, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp",
 )
 
 parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
@@ -159,6 +159,11 @@ parser.add_argument("--expr_result_max", type=int, default=99)
 
 parser.add_argument("--expr_input_file", type=str, default=None)
 
+##############################
+# twocuts
+
+parser.add_argument("--twocuts_no_global", action="store_true", default=False, help="disable the global mean constraint of the twocuts problem")
+
 ######################################################################
 
 args = parser.parse_args()
@@ -249,6 +254,12 @@ default_task_args = {
         "nb_train_samples": 50000,
         "nb_test_samples": 10000,
     },
+    "twocuts": {
+        "model": "37M",
+        "batch_size": 25,
+        "nb_train_samples": 100000,
+        "nb_test_samples": 10000,
+    },
     "mnist": {
         "model": "37M",
         "batch_size": 10,
@@ -403,6 +414,16 @@ elif args.task == "twotargets":
         device=device,
     )
 
+elif args.task == "twocuts":
+    task = tasks.SandBox(
+        problem=problems.ProblemTwoCuts(global_constraint=not args.twocuts_no_global),
+        nb_train_samples=args.nb_train_samples,
+        nb_test_samples=args.nb_test_samples,
+        batch_size=args.batch_size,
+        logger=log_string,
+        device=device,
+    )
+
 elif args.task == "addition":
     task = tasks.SandBox(
         problem=problems.ProblemAddition(),
diff --git a/problems.py b/problems.py
index 2c8602c..68a46b3 100755 (executable)
--- a/problems.py
+++ b/problems.py
@@ -17,6 +17,114 @@ class Problem:
     def seq2str(self, seq):
         return "[NOT IMPLEMENTED]"
 
+    def compute_nb_correct(self, input, ar_mask, result):
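+        # Default metric: count the positions selected by ar_mask at which
+        # the generated tokens match the reference input.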
+        nb_total = ar_mask.sum().item()
+        nb_correct = ((result == input).long() * ar_mask).sum().item()
+        return nb_total, nb_correct
+
+####################
+
+
+class ProblemTwoCuts(Problem):
+    def __init__(self, len_total=50, nb_values=100, global_constraint=True):
+        self.len_total = len_total
+        self.nb_values = nb_values
+        self.global_constraint = global_constraint
+
+    def generate_sequences(self, nb):
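+        # Each sequence consists of three constant runs: value a on [0, u),
+        # value b on [u, v), and value c on [v, len_total).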
+
+        u = torch.randint(self.len_total, (nb,))
+        v = torch.randint(self.len_total, (nb,))
+
+        a = torch.randint(self.nb_values, (nb,))
+        b = torch.randint(self.nb_values, (nb,))
+        c = torch.randint(self.nb_values, (nb,))
+
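+        # Resample the cut points until every sequence satisfies u > 0 and
+        # len_total // 10 < v - u <= len_total // 5.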
+        while True:
+            to_compute = torch.logical_or(u >= v - self.len_total // 10, u < v - self.len_total // 5)
+            to_compute = torch.logical_or(to_compute, u == 0)
+            to_compute = torch.logical_or(to_compute, v == self.len_total)
+            n = to_compute.long().sum()
+            if n == 0:
+                break
+            else:
+                u[to_compute] = torch.randint(self.len_total, (n,))
+                v[to_compute] = torch.randint(self.len_total, (n,))
+
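+        # Resample the values until a, b, c are pairwise distinct and, when
+        # the global constraint is enabled, the run-length weighted sum
+        # floor-divided by len_total equals nb_values // 2.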
+        while True:
+            to_compute = a == b
+            to_compute = torch.logical_or(to_compute, b == c)
+            to_compute = torch.logical_or(to_compute, a == c)
+
+            if self.global_constraint:
+                to_compute = torch.logical_or(
+                    to_compute,
+                    (a * u + b * (v - u) + c * (self.len_total - v)) // self.len_total
+                    != self.nb_values // 2,
+                )
+
+            n = to_compute.long().sum()
+            if n == 0:
+                break
+            else:
+                a[to_compute] = torch.randint(self.nb_values, (n,))
+                b[to_compute] = torch.randint(self.nb_values, (n,))
+                c[to_compute] = torch.randint(self.nb_values, (n,))
+
+        assert (u >= v).long().sum() == 0
+        assert (a == b).long().sum() == 0
+        assert (a == c).long().sum() == 0
+        assert (b == c).long().sum() == 0
+
+        t = torch.arange(self.len_total)
+        seq = (
+            (t[None, :] < u[:, None]).long() * a[:, None]
+            + (t[None, :] >= u[:, None]).long() * (t[None, :] < v[:, None]).long() * b[:, None]
+            + (t[None, :] >= v[:, None]).long() * c[:, None]
+        )
+
+        return seq, seq.new_full(seq.size(), 1, dtype=torch.int64)
+
+    def compute_nb_correct(self, input, ar_mask, result):
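+        # A sequence is counted as correct if it consists of exactly three
+        # constant runs with pairwise distinct values and, when the global
+        # constraint is enabled, satisfies the same mean condition as at
+        # generation time.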
+        nb_total = result.size(0)
+        nb_correct = 0
+        i = torch.arange(result.size(1), device=result.device)
+
+        for k in range(nb_total):
+            s = result[k]
+            a = s[0]
+            uu = (s != a).nonzero()
+            if uu.size(0) > 0:
+                u = uu.min()
+                b = s[u]
+                vv = torch.logical_and(s != b, i >= u).nonzero()
+                if vv.size(0) > 0:
+                    v = vv.min()
+                    c = s[v]
+                    ww = torch.logical_and(s != c, i >= v).nonzero()
+                    if ww.size(0) == 0:
+                        if not self.global_constraint or (
+                            a * u + b * (v - u) + c * (self.len_total - v)
+                        ) // self.len_total == self.nb_values // 2:
+                            nb_correct += 1
+
+        return nb_total, nb_correct
+
+    def seq2str(self, seq):
+        return " ".join( [ f"{x:02d}" for x in seq ] )
 
 ####################
 
@@ -197,7 +305,6 @@ class ProblemAddition(Problem):
 
 
 if __name__ == "__main__":
-    p = ProblemTwoTargets(12, 4)
-    s, m = p.generate_sequences(10)
-    for x in s:
-        print(p.seq2str(x))
+    p = ProblemTwoCuts(12)
+    s, m = p.generate_sequences(10000)
+    print(p.compute_nb_correct(None, None, s))
diff --git a/tasks.py b/tasks.py
index b33dee2..0858282 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -110,13 +110,14 @@ class SandBox(Task):
 
         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
 
+
         # A bit of paranoia never hurts
         assert (
             self.nb_codes <= max_nb_codes
             and self.train_input.min() >= 0
             and self.test_input.min() >= 0
-            and tuple(self.train_ar_mask.unique()) == (0, 1)
-            and tuple(self.test_ar_mask.unique()) == (0, 1)
+            and tuple(x.item() for x in self.train_ar_mask.unique()) in {(0,), (1,), (0, 1)}
+            and tuple(x.item() for x in self.test_ar_mask.unique()) in {(0,), (1,), (0, 1)}
         )
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
@@ -160,8 +161,9 @@
                         f"               {n_epoch} ground truth {self.problem.seq2str(st)}"
                     )
 
-            nb_total = ar_mask.sum().item()
-            nb_correct = ((result == input).long() * ar_mask).sum().item()
+            nb_total, nb_correct = self.problem.compute_nb_correct(
+                input, ar_mask, result
+            )
 
             return nb_total, nb_correct