Update.
author François Fleuret <francois@fleuret.org>
Sat, 1 Jul 2023 18:50:15 +0000 (20:50 +0200)
committer François Fleuret <francois@fleuret.org>
Sat, 1 Jul 2023 18:50:15 +0000 (20:50 +0200)
main.py
stack.py

diff --git a/main.py b/main.py
index 0323d02..14b1bc3 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -109,7 +109,7 @@ parser.add_argument("--snake_length", type=int, default=200)
 ##############################
 # Snake options
 
-parser.add_argument("--stack_nb_steps", type=int, default=25)
+parser.add_argument("--stack_nb_steps", type=int, default=100)
 
 parser.add_argument("--stack_nb_stacks", type=int, default=1)
 
@@ -166,9 +166,9 @@ default_args = {
         "nb_test_samples": 10000,
     },
     "stack": {
-        "nb_epochs": 25,
+        "nb_epochs": 5,
         "batch_size": 25,
-        "nb_train_samples": 10000,
+        "nb_train_samples": 100000,
         "nb_test_samples": 1000,
     },
 }
@@ -892,6 +892,13 @@ class TaskStack(Task):
             nb_test_samples, nb_steps, nb_stacks, nb_values, self.device
         )
 
+        mask = self.test_input.clone()
+        stack.remove_poped_values(mask,self.nb_stacks)
+        mask=(mask!=self.test_input)
+        counts = self.test_stack_counts.flatten()[mask.flatten()]
+        counts=F.one_hot(counts).sum(0)
+        log_string(f"stack_count {counts}")
+
         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
diff --git a/stack.py b/stack.py
index 312b39f..ba452aa 100755 (executable)
--- a/stack.py
+++ b/stack.py
@@ -45,12 +45,13 @@ def remove_poped_values(seq, nb_stacks):
     seq[:, 1:] = -m[:, :-1] + (1 - m[:, :-1]) * seq[:, 1:]
 
 
-def seq_to_str(seq, recorded_stack_counts=None):
+def seq_to_str(seq, show_stack_nb=True,recorded_stack_counts=None):
     assert seq.size(0) % 2 == 0
     s = ""
     for t in range(seq.size(0) // 2):
-        op = seq[2 * t]
-        op = f"POP_{op//2}" if op % 2 == 1 else f"PSH_{op//2}"
+        n_op = seq[2 * t]
+        op = f"POP" if n_op % 2 == 1 else f"PSH"
+        if show_stack_nb: op+=f"_{n_op//2}"
         if seq[2 * t + 1] == -1:
             val = "?"
         else:
@@ -71,13 +72,15 @@ if __name__ == "__main__":
         nb=nb, nb_steps=nb_steps, nb_stacks=nb_stacks, nb_values=nb_values
     )
 
+    print("-- TRAIN -----------------------------")
+
     for n in range(min(10, seq.size(0))):
         # print(seq_to_str(seq[n], recorded_stack_counts[n]))
-        print(seq_to_str(seq[n]))
+        print(seq_to_str(seq[n],show_stack_nb=nb_stacks>1))
 
-    print("--------------------------------------")
+    print("-- TEST ------------------------------")
 
     remove_poped_values(seq, nb_stacks)
 
     for n in range(min(10, seq.size(0))):
-        print(seq_to_str(seq[n]))
+        print(seq_to_str(seq[n],show_stack_nb=nb_stacks>1))