Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 20 Jul 2023 12:11:54 +0000 (14:11 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 20 Jul 2023 12:11:54 +0000 (14:11 +0200)
rpl.py
tasks.py

diff --git a/rpl.py b/rpl.py
index 7c1c96e..8d31efe 100755 (executable)
--- a/rpl.py
+++ b/rpl.py
@@ -58,7 +58,9 @@ rpl_ops = ["add", "min", "max", "swp", "rep", "dup", "del"]
 ######################################################################
 
 
-def generate(nb_starting_values=3, max_input=9, prog_len=6, nb_runs=5):
+def generate(
+    nb_starting_values=3, nb_result_values_max=None, max_input=9, prog_len=6, nb_runs=5
+):
     prog_len = (1 + torch.randint(2 * prog_len, (1,))).clamp(max=prog_len).item()
 
     while True:
@@ -77,7 +79,10 @@ def generate(nb_starting_values=3, max_input=9, prog_len=6, nb_runs=5):
 
         result = result + ["<prog>"] + prog
         result = result + ["<end>"]
-        if no_empty_stack:
+
+        if no_empty_stack and (
+            nb_result_values_max is None or len(result_stack) <= nb_result_values_max
+        ):
             break
 
     return result
index 889d4a9..0827a44 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -1070,6 +1070,7 @@ class RPL(Task):
         train_sequences = [
             rpl.generate(
                 nb_starting_values=nb_starting_values,
+                nb_result_values_max=4 * nb_starting_values,
                 max_input=max_input,
                 prog_len=prog_len,
                 nb_runs=nb_runs,
@@ -1080,6 +1081,7 @@ class RPL(Task):
         test_sequences = [
             rpl.generate(
                 nb_starting_values=nb_starting_values,
+                nb_result_values_max=4 * nb_starting_values,
                 max_input=max_input,
                 prog_len=prog_len,
                 nb_runs=nb_runs,