Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 18 Jul 2023 22:59:24 +0000 (00:59 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 18 Jul 2023 22:59:24 +0000 (00:59 +0200)
main.py
tasks.py

diff --git a/main.py b/main.py
index 0d4930d..19f918c 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -86,7 +86,7 @@ parser.add_argument("--sandbox_level", type=int, default=0)
 
 parser.add_argument("--sandbox_levels_nb_items", type=int, default=25)
 
-parser.add_argument("--sandbox_levels_len_source", type=int, default=5)
+parser.add_argument("--sandbox_levels_len_source", type=int, default=6)
 
 parser.add_argument("--sandbox_levels_len_result", type=int, default=8)
 
@@ -163,9 +163,9 @@ if args.result_dir is None:
 
 default_args = {
     "sandbox": {
-        "nb_epochs": 10,
+        "nb_epochs": 50,
         "batch_size": 25,
-        "nb_train_samples": 25000,
+        "nb_train_samples": 100000,
         "nb_test_samples": 10000,
     },
     "picoclvr": {
index e7c2f75..c5418b4 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -104,7 +104,8 @@ class ProblemLevel1(Problem):
             // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
         ) % 10
         marker1 = torch.full((nb, 1), 10)
-        source = torch.randint(10, (nb, self.len_source))
+        # source = torch.randint(10, (nb, self.len_source))
+        source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
         marker2 = torch.full((nb, 1), 11)
         result = operators.bmm(source[:, :, None]).squeeze(-1)
         print(f"{nb_operators.dtype=} {marker1.dtype=}")
@@ -128,7 +129,8 @@ class ProblemLevel2(Problem):
             torch.rand(nb, self.len_result, self.len_source).argmax(-1),
             num_classes=self.len_source,
         )
-        source1 = torch.randint(10, (nb, self.len_source))
+        source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
+        # source1 = torch.randint(10, (nb, self.len_source))
         marker1 = torch.full((nb, 1), 10)
         result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
         marker2 = torch.full((nb, 1), 11)