Add a sandbox task skeleton, an entropy regularizer for the world encoder, and episode-generation tweaks.
author François Fleuret <francois@fleuret.org>
Mon, 17 Jul 2023 10:57:36 +0000 (12:57 +0200)
committer François Fleuret <francois@fleuret.org>
Mon, 17 Jul 2023 10:57:36 +0000 (12:57 +0200)
main.py
tasks.py
world.py

diff --git a/main.py b/main.py
index 305bd3c..69ee58f 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -34,8 +34,8 @@ parser = argparse.ArgumentParser(
 parser.add_argument(
     "--task",
     type=str,
-    default="picoclvr",
-    help="picoclvr, mnist, maze, snake, stack, expr, world",
+    default="sandbox",
+    help="sandbox, picoclvr, mnist, maze, snake, stack, expr, world",
 )
 
 parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
@@ -150,6 +150,12 @@ if args.result_dir is None:
 ######################################################################
 
 default_args = {
+    "sandbox": {
+        "nb_epochs": 10,
+        "batch_size": 25,
+        "nb_train_samples": 25000,
+        "nb_test_samples": 10000,
+    },
     "picoclvr": {
         "nb_epochs": 25,
         "batch_size": 25,
@@ -189,7 +195,7 @@ default_args = {
     "world": {
         "nb_epochs": 10,
         "batch_size": 25,
-        "nb_train_samples": 125000,
+        "nb_train_samples": 25000,
         "nb_test_samples": 1000,
     },
 }
@@ -257,7 +263,16 @@ picoclvr_pruner_eval = (
 
 ######################################################################
 
-if args.task == "picoclvr":
+if args.task == "sandbox":
+    task = tasks.SandBox(
+        nb_train_samples=args.nb_train_samples,
+        nb_test_samples=args.nb_test_samples,
+        batch_size=args.batch_size,
+        logger=log_string,
+        device=device,
+    )
+
+elif args.task == "picoclvr":
     task = tasks.PicoCLVR(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
diff --git a/tasks.py b/tasks.py
index f8fb9b9..8b57cb2 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -60,6 +60,69 @@ class Task:
         pass
 
 
+######################################################################
+
+
+class Problem:
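+    # Interface for toy problems: generate() should return nb token sequences,
+    # perf() should evaluate generated sequences and report through the logger.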
+    def generate(self, nb):
+        pass
+
+    def perf(self, seq, logger):
+        pass
+
+
+class ProblemByheart(Problem):
+    def __init__(self):
+        pass
+
+
+class SandBox(Task):
+    def __init__(
+        self,
+        nb_train_samples,
+        nb_test_samples,
+        batch_size,
+        logger=None,
+        device=torch.device("cpu"),
+    ):
+        super().__init__()
+
+        self.batch_size = batch_size
+
+        def generate_sequences(nb_samples):
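+            # assign each sample to a randomly drawn problem and count how many
+            # samples each problem will have to generate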
+            problem_indexes = torch.randint(len(problems), (nb_samples,))
+            nb_samples_per_problem = torch.nn.functional.one_hot(
+                problem_indexes, num_classes=len(problems)
+            ).sum(0)
+            print(f"{nb_samples_per_problem}")
+
+        self.train_input = generate_sequences(nb_train_samples)
+        self.test_input = generate_sequences(nb_test_samples)
+
+        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+
+    def batches(self, split="train", nb_to_use=-1, desc=None):
+        assert split in {"train", "test"}
+        input = self.train_input if split == "train" else self.test_input
+        if nb_to_use > 0:
+            input = input[:nb_to_use]
+        if desc is None:
+            desc = f"epoch-{split}"
+        for batch in tqdm.tqdm(
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
+        ):
+            yield batch
+
+    def vocabulary_size(self):
+        return self.nb_codes
+
+    def produce_results(
+        self, n_epoch, model, result_dir, logger, deterministic_synthesis
+    ):
+        # logger(
+        # f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+        # )
+        pass
+
+
 ######################################################################
 
 import picoclvr
@@ -108,6 +171,8 @@ class PicoCLVR(Task):
         pruner_train=None,
         pruner_eval=None,
     ):
+        super().__init__()
+
         def generate_descr(nb, cache_suffix, pruner):
             return picoclvr.generate(
                 nb,
@@ -296,6 +361,8 @@ class MNIST(Task):
     def __init__(
         self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
     ):
+        super().__init__()
+
         self.nb_train_samples = (nb_train_samples,)
         self.nb_test_samples = (nb_test_samples,)
         self.batch_size = batch_size
@@ -366,6 +433,8 @@ class Maze(Task):
         nb_walls,
         device=torch.device("cpu"),
     ):
+        super().__init__()
+
         self.batch_size = batch_size
         self.height = height
         self.width = width
@@ -537,6 +606,8 @@ class Snake(Task):
         prompt_length,
         device=torch.device("cpu"),
     ):
+        super().__init__()
+
         self.batch_size = batch_size
         self.height = height
         self.width = width
@@ -635,6 +706,8 @@ class Stack(Task):
         fraction_values_for_train=None,
         device=torch.device("cpu"),
     ):
+        super().__init__()
+
         self.batch_size = batch_size
         self.nb_steps = nb_steps
         self.nb_stacks = nb_stacks
@@ -782,6 +855,8 @@ class Expr(Task):
         batch_size,
         device=torch.device("cpu"),
     ):
+        super().__init__()
+
         self.batch_size = batch_size
         self.device = device
 
@@ -961,6 +1036,8 @@ class World(Task):
         device=torch.device("cpu"),
         device_storage=torch.device("cpu"),
     ):
+        super().__init__()
+
         self.batch_size = batch_size
         self.device = device
 
diff --git a/world.py b/world.py
index fa305cf..da7de75 100755 (executable)
--- a/world.py
+++ b/world.py
@@ -62,12 +62,20 @@ class SignSTE(nn.Module):
             return s
 
 
+def loss_H(binary_logits, h_threshold=1):
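+    # p: per-bit activation frequency over the batch; h: its binary entropy in bits.
+    # Bits whose entropy falls below h_threshold are penalized, which pushes every
+    # latent bit to remain informative.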
+    p = binary_logits.sigmoid().mean(0)
+    h = (-p.xlogy(p) - (1 - p).xlogy(1 - p)) / math.log(2)
+    h.clamp_(max=h_threshold)
+    return h_threshold - h.mean()
+
+
 def train_encoder(
     train_input,
     test_input,
     depth=2,
     dim_hidden=48,
     nb_bits_per_token=8,
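+    # weight of the entropy regularizer on the code bits (0 disables it)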
+    lambda_entropy=0.0,
     lr_start=1e-3,
     lr_end=1e-4,
     nb_epochs=10,
@@ -160,6 +168,9 @@ def train_encoder(
 
             train_loss = F.cross_entropy(output, input)
 
+            if lambda_entropy > 0:
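+                # optional entropy term that keeps the quantizer bits from collapsing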
+                train_loss = train_loss + lambda_entropy * loss_H(z, h_threshold=0.5)
+
             acc_train_loss += train_loss.item() * input.size(0)
 
             optimizer.zero_grad()
@@ -238,7 +249,7 @@ def scene2tensor(xh, yh, scene, size):
     )
 
 
-def random_scene():
+def random_scene(nb_insert_attempts=3):
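+    # attempt to insert nb_insert_attempts random boxes into the scene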
     scene = []
     colors = [
         ((Box.nb_rgb_levels - 1), 0, 0),
@@ -252,7 +263,7 @@ def random_scene():
         ),
     ]
 
-    for k in range(10):
+    for k in range(nb_insert_attempts):
         wh = torch.rand(2) * 0.2 + 0.2
         xy = torch.rand(2) * (1 - wh)
         c = colors[torch.randint(len(colors), (1,))]
@@ -286,14 +297,15 @@ def generate_episode(steps, size=64):
         xh, yh = tuple(x.item() for x in torch.rand(2))
 
         actions = torch.randint(len(effects), (len(steps),))
-        change = False
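+        # count the steps in which a grasped box is actually displaced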
+        nb_changes = 0
 
         for s, a in zip(steps, actions):
             if s:
                 frames.append(scene2tensor(xh, yh, scene, size=size))
 
-            g, dx, dy = effects[a]
-            if g:
+            grasp, dx, dy = effects[a]
+
+            if grasp:
                 for b in scene:
                     if b.x <= xh and b.x + b.w >= xh and b.y <= yh and b.y + b.h >= yh:
                         x, y = b.x, b.y
@@ -310,7 +322,7 @@ def generate_episode(steps, size=64):
                         else:
                             xh += dx
                             yh += dy
-                            change = True
+                            nb_changes += 1
             else:
                 x, y = xh, yh
                 xh += dx
@@ -318,7 +330,7 @@ def generate_episode(steps, size=64):
                 if xh < 0 or xh > 1 or yh < 0 or yh > 1:
                     xh, yh = x, y
 
-        if change:
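+        # keep regenerating the episode until boxes were moved in more than a third of the steps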
+        if nb_changes > len(steps) // 3:
             break
 
     return frames, actions
@@ -352,12 +364,21 @@ def create_data_and_processors(
         steps = [True] + [False] * (nb_steps + 1) + [True]
 
     train_input, train_actions = generate_episodes(nb_train_samples, steps)
-    train_input, train_actions = train_input.to(device_storage), train_actions.to(device_storage)
+    train_input, train_actions = train_input.to(device_storage), train_actions.to(
+        device_storage
+    )
     test_input, test_actions = generate_episodes(nb_test_samples, steps)
-    test_input, test_actions = test_input.to(device_storage), test_actions.to(device_storage)
+    test_input, test_actions = test_input.to(device_storage), test_actions.to(
+        device_storage
+    )
 
     encoder, quantizer, decoder = train_encoder(
-        train_input, test_input, nb_epochs=nb_epochs, logger=logger, device=device
+        train_input,
+        test_input,
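+        # enable the entropy regularizer so that all code bits stay informative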
+        lambda_entropy=1.0,
+        nb_epochs=nb_epochs,
+        logger=logger,
+        device=device,
     )
     encoder.train(False)
     quantizer.train(False)
@@ -371,7 +392,7 @@ def create_data_and_processors(
         seq = []
         p = pow2.to(device)
         for x in input.split(batch_size):
-            x=x.to(device)
+            x = x.to(device)
             z = encoder(x)
             ze_bool = (quantizer(z) >= 0).long()
             output = (