Update.
author     François Fleuret <francois@fleuret.org>
           Sat, 16 Sep 2023 10:24:21 +0000 (12:24 +0200)
committer  François Fleuret <francois@fleuret.org>
           Sat, 16 Sep 2023 10:24:21 +0000 (12:24 +0200)
main.py
snake.py
tasks.py

diff --git a/main.py b/main.py
index 7197414..cd37b94 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -159,11 +159,6 @@ parser.add_argument("--expr_result_max", type=int, default=99)
 
 parser.add_argument("--expr_input_file", type=str, default=None)
 
-##############################
-# World options
-
-parser.add_argument("--world_vqae_nb_epochs", type=int, default=25)
-
 ######################################################################
 
 args = parser.parse_args()
@@ -248,19 +243,12 @@ default_task_args = {
         "nb_train_samples": 50000,
         "nb_test_samples": 10000,
     },
-
     "mnist": {
         "model": "37M",
         "batch_size": 10,
         "nb_train_samples": 60000,
         "nb_test_samples": 10000,
     },
-    "world": {
-        "model": "37M",
-        "batch_size": 25,
-        "nb_train_samples": 25000,
-        "nb_test_samples": 1000,
-    },
 }
 
 if args.task in default_task_args:
@@ -514,16 +502,6 @@ elif args.task == "grid":
         device=device,
     )
 
-elif args.task == "world":
-    task = tasks.World(
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
-        vqae_nb_epochs=args.world_vqae_nb_epochs,
-        logger=log_string,
-        device=device,
-    )
-
 else:
     raise ValueError(f"Unknown task {args.task}")
 
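Note on the surviving code around the second hunk: default_task_args is a per-task table of defaults, and the `if args.task in default_task_args:` line that follows it is where those defaults get applied. A minimal, runnable sketch of that pattern, assuming (this is not the file's actual code) the defaults only back-fill options the user left at None:

    import argparse

    # toy table, trimmed from the one in main.py; only the "mnist" values are real
    default_task_args = {
        "mnist": {"model": "37M", "batch_size": 10},
    }

    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="mnist")
    parser.add_argument("--model", type=str, default=None)
    parser.add_argument("--batch_size", type=int, default=None)
    args = parser.parse_args([])

    if args.task in default_task_args:
        for name, value in default_task_args[args.task].items():
            if getattr(args, name) is None:  # only fill options left unset
                setattr(args, name, value)

    print(args.model, args.batch_size)  # 37M 10
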
diff --git a/snake.py b/snake.py
index 7c34941..8a16f9f 100755 (executable)
--- a/snake.py
+++ b/snake.py
@@ -111,45 +111,22 @@ def solver(input, ar_mask):
             # print(f'@2 {i=} {j=}')
 
 
+def seq2str(seq):
+    return "".join(["NESW123456789"[i] for i in seq])
+
+
 ######################################################################
 
 if __name__ == "__main__":
-    import cairo, numpy, math
-
-    color_name2rgb = {
-        "red": [255, 0, 0],
-        "green": [0, 128, 0],
-        "blue": [0, 0, 255],
-        "yellow": [255, 255, 0],
-        "orange": [255, 128, 0],
-        "maroon": [128, 0, 0],
-        "dark_red": [139, 0, 0],
-        "brown": [165, 42, 42],
-        "firebrick": [178, 34, 34],
-        "crimson": [220, 20, 60],
-        "tomato": [255, 99, 71],
-        "coral": [255, 127, 80],
-        "indian_red": [205, 92, 92],
-        "light_coral": [240, 128, 128],
-        "dark_salmon": [233, 150, 122],
-        "salmon": [250, 128, 114],
-    }
-
-    sequences, sequences_prior_visits, worlds, world_prior_visits = generate_sequences(
-        8, 6, 8, 5, 20, 10
+    train_input, train_prior_visits, _, _ = generate_sequences(
+        nb=20,
+        height=9,
+        width=12,
+        nb_colors=5,
+        length=50,
+        prompt_length=100,
     )
 
-    delta = 16
-    height, width = sequences.size(0) * 16, sequences.size(1) * 16
-    pixel_map = torch.ByteTensor(width, height, 4).fill_(0).numpy()
-    surface = cairo.ImageSurface.create_for_data(
-        pixel_map, cairo.FORMAT_ARGB32, width, height
-    )
-    ctx = cairo.Context(surface)
-    ctx.set_line_width(1.0)
-
-    ctx.set_fill_rule(cairo.FILL_RULE_EVEN_ODD)
-
-    ctx.fill()
+    print([seq2str(s) for s in train_input])
 
 ######################################################################
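
The seq2str helper added above just indexes into the alphabet "NESW123456789", turning each integer token of a snake sequence into one character. A small runnable illustration; the token values below are made up and are not real output of generate_sequences:

    def seq2str(seq):
        # as in the patch: token i maps to the i-th character of the alphabet
        return "".join(["NESW123456789"[i] for i in seq])

    # arbitrary token indices, purely for illustration
    print(seq2str([0, 4, 1, 5, 2, 6, 3, 7]))  # -> "N1E2S3W4"
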
diff --git a/tasks.py b/tasks.py
index d787c59..183c3cf 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -1550,125 +1550,3 @@ class Grid(Task):
 
 
 ######################################################################
-
-import world
-
-
-class World(Task):
-    def __init__(
-        self,
-        nb_train_samples,
-        nb_test_samples,
-        batch_size,
-        vqae_nb_epochs,
-        logger=None,
-        device=torch.device("cpu"),
-        device_storage=torch.device("cpu"),
-    ):
-        super().__init__()
-
-        self.batch_size = batch_size
-        self.device = device
-
-        (
-            train_frames,
-            train_action_seq,
-            test_frames,
-            test_action_seq,
-            self.frame2seq,
-            self.seq2frame,
-        ) = world.create_data_and_processors(
-            nb_train_samples,
-            nb_test_samples,
-            mode="first_last",
-            nb_steps=30,
-            nb_epochs=vqae_nb_epochs,
-            logger=logger,
-            device=device,
-            device_storage=device_storage,
-        )
-
-        train_frame_seq = self.frame2seq(train_frames).to(device_storage)
-        test_frame_seq = self.frame2seq(test_frames).to(device_storage)
-
-        nb_frame_codes = max(train_frame_seq.max(), test_frame_seq.max()) + 1
-        nb_action_codes = max(train_action_seq.max(), test_action_seq.max()) + 1
-
-        self.len_frame_seq = train_frame_seq.size(1)
-        self.len_action_seq = train_action_seq.size(1)
-        self.nb_codes = nb_frame_codes + nb_action_codes
-
-        train_frame_seq = train_frame_seq.reshape(train_frame_seq.size(0) // 2, 2, -1)
-
-        train_action_seq += nb_frame_codes
-        self.train_input = torch.cat(
-            (train_frame_seq[:, 0, :], train_action_seq, train_frame_seq[:, 1, :]), 1
-        )
-
-        test_frame_seq = test_frame_seq.reshape(test_frame_seq.size(0) // 2, 2, -1)
-        test_action_seq += nb_frame_codes
-        self.test_input = torch.cat(
-            (test_frame_seq[:, 0, :], test_action_seq, test_frame_seq[:, 1, :]), 1
-        )
-
-    def batches(self, split="train", nb_to_use=-1, desc=None):
-        assert split in {"train", "test"}
-        input = self.train_input if split == "train" else self.test_input
-        if nb_to_use > 0:
-            input = input[:nb_to_use]
-        if desc is None:
-            desc = f"epoch-{split}"
-        for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=desc
-        ):
-            yield batch.to(self.device)
-
-    def vocabulary_size(self):
-        return self.nb_codes
-
-    def produce_results(
-        self, n_epoch, model, result_dir, logger, deterministic_synthesis
-    ):
-        k = torch.arange(
-            2 * self.len_frame_seq + self.len_action_seq, device=self.device
-        )[None, :]
-
-        input = self.test_input[:64].to(self.device)
-        result = input.clone()
-
-        ar_mask = (
-            (k >= self.len_frame_seq + self.len_action_seq).long().expand_as(result)
-        )
-        result *= 1 - ar_mask
-
-        masked_inplace_autoregression(
-            model,
-            self.batch_size,
-            result,
-            ar_mask,
-            deterministic_synthesis,
-            device=self.device,
-        )
-
-        seq_start = input[:, : self.len_frame_seq]
-        seq_end = input[:, self.len_frame_seq + self.len_action_seq :]
-        seq_predicted = result[:, self.len_frame_seq + self.len_action_seq :]
-
-        result = torch.cat(
-            (seq_start[:, None, :], seq_end[:, None, :], seq_predicted[:, None, :]), 1
-        )
-        result = result.reshape(-1, result.size(-1))
-
-        frames = self.seq2frame(result)
-        image_name = os.path.join(result_dir, f"world_result_{n_epoch:04d}.png")
-        torchvision.utils.save_image(
-            frames.float() / (world.Box.nb_rgb_levels - 1),
-            image_name,
-            nrow=12,
-            padding=1,
-            pad_value=0.0,
-        )
-        logger(f"wrote {image_name}")
-
-
-######################################################################
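
For reference, the deleted World.produce_results built its autoregressive mask purely from position indices: the first frame and the action tokens were kept, and everything after them was regenerated by masked_inplace_autoregression. A toy-sized sketch of that masking step; the sequence lengths below are invented (the real ones came from the VQ-AE encoder):

    import torch

    # invented toy lengths, for illustration only
    len_frame_seq, len_action_seq = 4, 2

    k = torch.arange(2 * len_frame_seq + len_action_seq)[None, :]
    ar_mask = (k >= len_frame_seq + len_action_seq).long()

    print(ar_mask)
    # tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1, 1]])
    # zeros: first frame + actions are kept; ones: second frame gets generated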