Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 26 Jul 2023 21:32:24 +0000 (11:32 -1000)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 26 Jul 2023 21:32:24 +0000 (11:32 -1000)
do_all.sh [new file with mode: 0755]
picoclvr.py
problems.py

diff --git a/do_all.sh b/do_all.sh
new file mode 100755 (executable)
index 0000000..76f1982
--- /dev/null
+++ b/do_all.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+
+##################################################################
+# START_IP_HEADER                                                #
+#                                                                #
+# Written by Francois Fleuret                                    #
+# Contact <francois.fleuret@unige.ch> for comments & bug reports #
+#                                                                #
+# END_IP_HEADER                                                  #
+##################################################################
+
+# set -e
+# set -o pipefail
+
+#prefix="--nb_train_samples=1000 --nb_test_samples=100 --batch_size=25 --nb_epochs=2 --max_percents_of_test_in_train=-1 --model=17K"
+prefix="--nb_epochs=2"
+
+for task in byheart learnop guessop twotargets addition picoclvr maze snake stack expr rpl
+do
+    [[ ! -d results_${task} ]] && ./main.py ${prefix} --task=${task}
+done
+
index 5da3943..0cd3062 100755 (executable)
@@ -5,6 +5,7 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
+import math
 import torch, torchvision
 import torch.nn.functional as F
 
@@ -201,7 +202,12 @@ def generate(
     descr = []
 
     for n in range(nb):
-        nb_squares = torch.randint(max_nb_squares, (1,)) + 1
+        # we want uniform over the combinations of 1 to max_nb_squares
+        # pixels of nb_colors
+        logits = math.log(nb_colors) * torch.arange(1, max_nb_squares + 1).float()
+        dist = torch.distributions.categorical.Categorical(logits=logits)
+        nb_squares = dist.sample((1,)) + 1
+        # nb_squares = torch.randint(max_nb_squares, (1,)) + 1
         square_position = torch.randperm(height * width)[:nb_squares]
 
         # color 0 is white and reserved for the background
index 5686404..2c8602c 100755 (executable)
@@ -87,7 +87,7 @@ class ProblemByHeart(Problem):
 
 
 class ProblemLearnOperator(Problem):
-    def __init__(self, nb_operators=100, len_source=5, len_result=8):
+    def __init__(self, nb_operators=100, len_source=6, len_result=9):
         self.len_source = len_source
         self.len_result = len_result
         self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1