Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 21 Jun 2023 20:10:48 +0000 (22:10 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 21 Jun 2023 20:10:48 +0000 (22:10 +0200)
main.py
snake.py

diff --git a/main.py b/main.py
index 9679236..7cb8d4f 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -102,7 +102,7 @@ parser.add_argument("--snake_width", type=int, default=8)
 
 parser.add_argument("--snake_nb_colors", type=int, default=5)
 
-parser.add_argument("--snake_length", type=int, default=400)
+parser.add_argument("--snake_length", type=int, default=200)
 
 ######################################################################
 
@@ -143,8 +143,8 @@ default_args = {
         "batch_size": 25,
     },
     "snake": {
-        "nb_epochs": 25,
-        "batch_size": 20,
+        "nb_epochs": 5,
+        "batch_size": 25,
     },
 }
 
@@ -689,7 +689,7 @@ class TaskSnake(Task):
         self.device = device
         self.prompt_length = prompt_length
 
-        self.train_input, self.train_prior_visits = snake.generate_sequences(
+        self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
             nb_train_samples,
             height,
             width,
@@ -698,7 +698,7 @@ class TaskSnake(Task):
             prompt_length,
             self.device,
         )
-        self.test_input, self.test_prior_visits = snake.generate_sequences(
+        self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
             nb_test_samples,
             height,
             width,
index eb46a07..7c34941 100755 (executable)
--- a/snake.py
+++ b/snake.py
@@ -13,7 +13,7 @@ def generate_sequences(
     nb, height, width, nb_colors, length, prompt_length, device=torch.device("cpu")
 ):
     worlds = torch.randint(nb_colors, (nb, height, width), device=device)
-    nb_prior_visits = torch.zeros(nb, height, width, device=device)
+    world_prior_visits = torch.zeros(nb, height, width, device=device)
 
     # nb x 2
     snake_position = torch.cat(
@@ -70,17 +70,17 @@ def generate_sequences(
         snake_direction = snake_next_direction[i, j]
 
         sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4
-        sequences_prior_visits[:, 2 * l] = nb_prior_visits[
+        sequences_prior_visits[:, 2 * l] = world_prior_visits[
             i, snake_position[:, 0], snake_position[:, 1]
         ]
         if l < prompt_length:
-            nb_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1
+            world_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1
         sequences[:, 2 * l + 1] = snake_direction
 
         # nb x 2
         snake_position = snake_next_position[i, j]
 
-    return sequences, sequences_prior_visits
+    return sequences, sequences_prior_visits, worlds, world_prior_visits
 
 
 # generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20)
@@ -114,32 +114,42 @@ def solver(input, ar_mask):
 ######################################################################
 
 if __name__ == "__main__":
-    for n in range(16):
-        descr = generate(nb=1, height=12, width=16)
-
-        print(nb_properties(descr, height=12, width=16))
-
-        with open(f"picoclvr_example_{n:02d}.txt", "w") as f:
-            for d in descr:
-                f.write(f"{d}\n\n")
-
-        img = descr2img(descr, height=12, width=16)
-        if img.size(0) == 1:
-            img = F.pad(img, (1, 1, 1, 1), value=64)
+    import cairo, numpy, math
+
+    color_name2rgb = {
+        "red": [255, 0, 0],
+        "green": [0, 128, 0],
+        "blue": [0, 0, 255],
+        "yellow": [255, 255, 0],
+        "orange": [255, 128, 0],
+        "maroon": [128, 0, 0],
+        "dark_red": [139, 0, 0],
+        "brown": [165, 42, 42],
+        "firebrick": [178, 34, 34],
+        "crimson": [220, 20, 60],
+        "tomato": [255, 99, 71],
+        "coral": [255, 127, 80],
+        "indian_red": [205, 92, 92],
+        "light_coral": [240, 128, 128],
+        "dark_salmon": [233, 150, 122],
+        "salmon": [250, 128, 114],
+    }
+
+    sequences, sequences_prior_visits, worlds, world_prior_visits = generate_sequences(
+        8, 6, 8, 5, 20, 10
+    )
 
-        torchvision.utils.save_image(
-            img / 255.0,
-            f"picoclvr_example_{n:02d}.png",
-            padding=1,
-            nrow=4,
-            pad_value=0.8,
-        )
+    delta = 16
+    height, width = sequences.size(0) * 16, sequences.size(1) * 16
+    pixel_map = torch.ByteTensor(width, height, 4).fill_(0).numpy()
+    surface = cairo.ImageSurface.create_for_data(
+        pixel_map, cairo.FORMAT_ARGB32, width, height
+    )
+    ctx = cairo.Context(surface)
+    ctx.set_line_width(1.0)
 
-    import time
+    ctx.set_fill_rule(cairo.FILL_RULE_EVEN_ODD)
 
-    start_time = time.perf_counter()
-    descr = generate(nb=1000, height=12, width=16)
-    end_time = time.perf_counter()
-    print(f"{len(descr) / (end_time - start_time):.02f} samples per second")
+    ctx.fill()
 
 ######################################################################