Update.
author François Fleuret <francois@fleuret.org>
Sat, 1 Jul 2023 17:42:47 +0000 (19:42 +0200)
committer François Fleuret <francois@fleuret.org>
Sat, 1 Jul 2023 17:42:47 +0000 (19:42 +0200)
main.py
stack.py

diff --git a/main.py b/main.py
index 45bddb7..0323d02 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -32,7 +32,7 @@ parser = argparse.ArgumentParser(
 )
 
 parser.add_argument(
-    "--task", type=str, default="picoclvr", help="picoclvr, mnist, maze, snake"
+    "--task", type=str, default="picoclvr", help="picoclvr, mnist, maze, snake, stack"
 )
 
 parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
@@ -106,6 +106,15 @@ parser.add_argument("--snake_nb_colors", type=int, default=5)
 
 parser.add_argument("--snake_length", type=int, default=200)
 
+##############################
+# Stack options
+
+parser.add_argument("--stack_nb_steps", type=int, default=25)
+
+parser.add_argument("--stack_nb_stacks", type=int, default=1)
+
+parser.add_argument("--stack_nb_values", type=int, default=10)
+
 ######################################################################
 
 args = parser.parse_args()
@@ -135,18 +144,32 @@ default_args = {
     "picoclvr": {
         "nb_epochs": 25,
         "batch_size": 25,
+        "nb_train_samples": 250000,
+        "nb_test_samples": 10000,
     },
     "mnist": {
         "nb_epochs": 25,
         "batch_size": 10,
+        "nb_train_samples": 250000,
+        "nb_test_samples": 10000,
     },
     "maze": {
         "nb_epochs": 25,
         "batch_size": 25,
+        "nb_train_samples": 250000,
+        "nb_test_samples": 10000,
     },
     "snake": {
         "nb_epochs": 5,
         "batch_size": 25,
+        "nb_train_samples": 250000,
+        "nb_test_samples": 10000,
+    },
+    "stack": {
+        "nb_epochs": 25,
+        "batch_size": 25,
+        "nb_train_samples": 10000,
+        "nb_test_samples": 1000,
     },
 }
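The new per-task nb_train_samples/nb_test_samples defaults above only take effect if something copies them into unset command-line arguments. That merging code sits outside this hunk; a minimal sketch of the usual pattern, assuming the corresponding argparse defaults are None, is:

    # Hypothetical sketch -- the merging code itself is not shown in this diff.
    # Any argument left at None on the command line falls back to the
    # per-task default defined in default_args.
    for k, v in default_args[args.task].items():
        if getattr(args, k) is None:
            setattr(args, k, v)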
 
@@ -841,6 +864,86 @@ class TaskSnake(Task):
 ######################################################################
 
 
+import stack
+
+
+class TaskStack(Task):
+    def __init__(
+        self,
+        nb_train_samples,
+        nb_test_samples,
+        batch_size,
+        nb_steps,
+        nb_stacks,
+        nb_values,
+        device=torch.device("cpu"),
+    ):
+        self.batch_size = batch_size
+        self.nb_steps = nb_steps
+        self.nb_stacks = nb_stacks
+        self.nb_values = nb_values
+        self.device = device
+
+        self.train_input, self.train_stack_counts = stack.generate_sequences(
+            nb_train_samples, nb_steps, nb_stacks, nb_values, self.device
+        )
+
+        self.test_input, self.test_stack_counts = stack.generate_sequences(
+            nb_test_samples, nb_steps, nb_stacks, nb_values, self.device
+        )
+
+        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+
+    def batches(self, split="train", nb_to_use=-1, desc=None):
+        assert split in {"train", "test"}
+        input = self.train_input if split == "train" else self.test_input
+        if nb_to_use > 0:
+            input = input[:nb_to_use]
+        if desc is None:
+            desc = f"epoch-{split}"
+        for batch in tqdm.tqdm(
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
+        ):
+            yield batch
+
+    def vocabulary_size(self):
+        return self.nb_codes
+
+    def produce_results(self, n_epoch, model):
+        with torch.autograd.no_grad():
+            t = model.training
+            model.eval()
+
+            def compute_nb_correct(input):
+                result = input.clone()
+                stack.remove_popped_values(result, self.nb_stacks)
+                ar_mask = (result != input).long()
+                result *= 1 - ar_mask
+
+                masked_inplace_autoregression(
+                    model, self.batch_size, result, ar_mask, device=self.device
+                )
+
+                nb_total = ar_mask.sum()
+
+                nb_correct = (
+                    (result == input).long() * ar_mask
+                ).sum()
+
+                return nb_total, nb_correct
+
+            test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
+
+            log_string(
+                f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+            )
+
+            model.train(t)
+
+
+######################################################################
+
+
 def picoclvr_pruner_horizontal_green(p):
     return not ("green" in p and ("left" in p or "right" in p))
 
@@ -902,6 +1005,17 @@ elif args.task == "snake":
         device=device,
     )
 
+elif args.task == "stack":
+    task = TaskStack(
+        nb_train_samples=args.nb_train_samples,
+        nb_test_samples=args.nb_test_samples,
+        batch_size=args.batch_size,
+        nb_steps=args.stack_nb_steps,
+        nb_stacks=args.stack_nb_stacks,
+        nb_values=args.stack_nb_values,
+        device=device,
+    )
+
 else:
     raise ValueError(f"Unknown task {args.task}")
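A quick way to sanity-check the new task wiring without training: generate a small batch with the stack module, blank the popped values exactly as TaskStack.produce_results does, and count how many positions the model would have to fill in. This is a standalone sketch, not part of the commit; it uses only the two functions added in stack.py below, with arbitrary sizes.

    import stack

    # Sketch: reproduce the masking step of TaskStack.produce_results
    # on a tiny batch.
    seq, _ = stack.generate_sequences(nb=4, nb_steps=5, nb_stacks=1, nb_values=10)
    masked = seq.clone()
    stack.remove_popped_values(masked, nb_stacks=1)  # popped values become -1
    ar_mask = (masked != seq).long()  # 1 where a popped value must be predicted
    print(f"{ar_mask.sum().item()} positions to predict out of {seq.numel()}")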
 
diff --git a/stack.py b/stack.py
index dc494bb..312b39f 100755 (executable)
--- a/stack.py
+++ b/stack.py
@@ -13,44 +13,52 @@ import torch, torchvision
 # CODE_VAL=val + 2 * nb_stacks
 
 
-def generate(nb, nb_steps, nb_stacks, nb_values):
+def generate_sequences(nb, nb_steps, nb_stacks, nb_values, device=torch.device("cpu")):
     stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64)
-    stack_pointers = torch.zeros(nb, nb_stacks, dtype=torch.int64)
+    stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64)
     k = torch.arange(nb)
     result = torch.empty(nb, 2 * nb_steps, dtype=torch.int64)
-    depth_counts = torch.zeros(nb, 2 * nb_steps, dtype=torch.int64)
+    recorded_stack_counts = torch.zeros(nb, 2 * nb_steps, dtype=torch.int64)
 
     for t in range(nb_steps):
         op = torch.randint(2, (nb,))
         st = torch.randint(nb_stacks, (nb,))
-        op = op * (stack_pointers[k, st] > 0)
+        op = op * (stack_counts[k, st] > 0)
         val_push = torch.randint(nb_values, (nb,))
         val_pop = stack[
             k,
             st,
-            (stack_pointers[k, st] - 1).clamp(min=0),
+            (stack_counts[k, st] - 1).clamp(min=0),
         ]
-        stack[k, st, stack_pointers[k, st]] = val_push
-        depth_counts[:, 2 * t + 1] = stack_pointers[k, st]
-        stack_pointers[k[op == 0], st[op == 0]] += 1
-        stack_pointers[k[op == 1], st[op == 1]] -= 1
+        stack[k, st, stack_counts[k, st]] = val_push
+        recorded_stack_counts[:, 2 * t + 1] = stack_counts[k, st]
+        stack_counts[k[op == 0], st[op == 0]] += 1
+        stack_counts[k[op == 1], st[op == 1]] -= 1
         result[:, 2 * t] = st * 2 + op
         result[:, 2 * t + 1] = (op * val_pop + (1 - op) * val_push) + 2 * nb_stacks
 
-    return result, depth_counts
+    return result.to(device), recorded_stack_counts.to(device)
 
 
-def seq_to_str(seq, depth_counts=None):
+def remove_popped_values(seq, nb_stacks):
+    m = torch.logical_and(seq % 2 == 1, seq < 2 * nb_stacks).long()
+    seq[:, 1:] = -m[:, :-1] + (1 - m[:, :-1]) * seq[:, 1:]
+
+
+def seq_to_str(seq, nb_stacks, recorded_stack_counts=None):
     assert seq.size(0) % 2 == 0
     s = ""
     for t in range(seq.size(0) // 2):
         op = seq[2 * t]
-        op = f"POP_{op//2}" if op % 2 == 1 else f"PUSH_{op//2}"
-        val = seq[2 * t + 1] - 2 * nb_stacks
+        op = f"POP_{op//2}" if op % 2 == 1 else f"PSH_{op//2}"
+        if seq[2 * t + 1] == -1:
+            val = "?"
+        else:
+            val = seq[2 * t + 1] - 2 * nb_stacks
         if t > 0:
             s += " "
-        if depth_counts is not None:
-            s += f"[{depth_counts[2*t+1]}] "
+        if recorded_stack_counts is not None:
+            s += f"[{recorded_stack_counts[2*t+1]}] "
         s += f"{op} {val}"
     return s
 
@@ -59,9 +67,17 @@ def seq_to_str(seq, depth_counts=None):
 
 if __name__ == "__main__":
     nb, nb_steps, nb_stacks, nb_values = 150000, 10, 1, 5
-    seq, depth_counts = generate(
+    seq, recorded_stack_counts = generate_sequences(
         nb=nb, nb_steps=nb_steps, nb_stacks=nb_stacks, nb_values=nb_values
     )
 
     for n in range(min(10, seq.size(0))):
-        print(seq_to_str(seq[n], depth_counts[n]))
+        # print(seq_to_str(seq[n], nb_stacks, recorded_stack_counts[n]))
+        print(seq_to_str(seq[n], nb_stacks))
+
+    print("--------------------------------------")
+
+    remove_poped_values(seq, nb_stacks)
+
+    for n in range(min(10, seq.size(0))):
+        print(seq_to_str(seq[n], nb_stacks))
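
For reference, the token layout used throughout stack.py (per the comment block at the top of the file and the code above): even positions hold an operation token st * 2 + op, where op 1 is a pop, and odd positions hold a value token val + 2 * nb_stacks, with -1 marking a value blanked by remove_popped_values. A small decoding sketch, not part of the commit, assuming that layout:

    # Sketch: decode one generated sequence into (op, stack, value)
    # triples, mirroring what seq_to_str prints.
    def decode(seq, nb_stacks):
        out = []
        for t in range(len(seq) // 2):
            code, v = int(seq[2 * t]), int(seq[2 * t + 1])
            op = "POP" if code % 2 == 1 else "PSH"
            val = "?" if v == -1 else v - 2 * nb_stacks
            out.append((op, code // 2, val))
        return out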