Update.
[picoclvr.git] / picoclvr.py
index 94c0f88..5da3943 100755 (executable)
 import torch, torchvision
 import torch.nn.functional as F
 
-colors = [
-    [255, 255, 255],
-    [255, 0, 0],
-    [0, 128, 0],
-    [0, 0, 255],
-    [255, 255, 0],
-    [0, 0, 0],
-    [128, 0, 0],
-    [139, 0, 0],
-    [165, 42, 42],
-    [178, 34, 34],
-    [220, 20, 60],
-    [255, 99, 71],
-    [255, 127, 80],
-    [205, 92, 92],
-    [240, 128, 128],
-    [233, 150, 122],
-    [250, 128, 114],
-    [255, 160, 122],
-    [255, 69, 0],
-    [255, 140, 0],
-    [255, 165, 0],
-    [255, 215, 0],
-    [184, 134, 11],
-    [218, 165, 32],
-    [238, 232, 170],
-    [189, 183, 107],
-    [240, 230, 140],
-    [128, 128, 0],
-    [154, 205, 50],
-    [85, 107, 47],
-    [107, 142, 35],
-    [124, 252, 0],
-    [127, 255, 0],
-    [173, 255, 47],
-    [0, 100, 0],
-    [34, 139, 34],
-    [0, 255, 0],
-    [50, 205, 50],
-    [144, 238, 144],
-    [152, 251, 152],
-    [143, 188, 143],
-    [0, 250, 154],
-    [0, 255, 127],
-    [46, 139, 87],
-    [102, 205, 170],
-    [60, 179, 113],
-    [32, 178, 170],
-    [47, 79, 79],
-    [0, 128, 128],
-    [0, 139, 139],
-    [0, 255, 255],
-    [0, 255, 255],
-    [224, 255, 255],
-    [0, 206, 209],
-    [64, 224, 208],
-    [72, 209, 204],
-    [175, 238, 238],
-    [127, 255, 212],
-    [176, 224, 230],
-    [95, 158, 160],
-    [70, 130, 180],
-    [100, 149, 237],
-    [0, 191, 255],
-    [30, 144, 255],
-    [173, 216, 230],
-    [135, 206, 235],
-    [135, 206, 250],
-    [25, 25, 112],
-    [0, 0, 128],
-    [0, 0, 139],
-    [0, 0, 205],
-    [65, 105, 225],
-    [138, 43, 226],
-    [75, 0, 130],
-    [72, 61, 139],
-    [106, 90, 205],
-    [123, 104, 238],
-    [147, 112, 219],
-    [139, 0, 139],
-    [148, 0, 211],
-    [153, 50, 204],
-    [186, 85, 211],
-    [128, 0, 128],
-    [216, 191, 216],
-    [221, 160, 221],
-    [238, 130, 238],
-    [255, 0, 255],
-    [218, 112, 214],
-    [199, 21, 133],
-    [219, 112, 147],
-    [255, 20, 147],
-    [255, 105, 180],
-    [255, 182, 193],
-    [255, 192, 203],
-    [250, 235, 215],
-    [245, 245, 220],
-    [255, 228, 196],
-    [255, 235, 205],
-    [245, 222, 179],
-    [255, 248, 220],
-    [255, 250, 205],
-    [250, 250, 210],
-    [255, 255, 224],
-    [139, 69, 19],
-    [160, 82, 45],
-    [210, 105, 30],
-    [205, 133, 63],
-    [244, 164, 96],
-    [222, 184, 135],
-    [210, 180, 140],
-    [188, 143, 143],
-    [255, 228, 181],
-    [255, 222, 173],
-    [255, 218, 185],
-    [255, 228, 225],
-    [255, 240, 245],
-    [250, 240, 230],
-    [253, 245, 230],
-    [255, 239, 213],
-    [255, 245, 238],
-    [245, 255, 250],
-    [112, 128, 144],
-    [119, 136, 153],
-    [176, 196, 222],
-    [230, 230, 250],
-    [255, 250, 240],
-    [240, 248, 255],
-    [248, 248, 255],
-    [240, 255, 240],
-    [255, 255, 240],
-    [240, 255, 255],
-    [255, 250, 250],
-    [192, 192, 192],
-    [220, 220, 220],
-    [245, 245, 245],
-]
-
-color_names = [
-    "white",
-    "red",
-    "green",
-    "blue",
-    "yellow",
-    "black",
-    "maroon",
-    "dark_red",
-    "brown",
-    "firebrick",
-    "crimson",
-    "tomato",
-    "coral",
-    "indian_red",
-    "light_coral",
-    "dark_salmon",
-    "salmon",
-    "light_salmon",
-    "orange_red",
-    "dark_orange",
-    "orange",
-    "gold",
-    "dark_golden_rod",
-    "golden_rod",
-    "pale_golden_rod",
-    "dark_khaki",
-    "khaki",
-    "olive",
-    "yellow_green",
-    "dark_olive_green",
-    "olive_drab",
-    "lawn_green",
-    "chartreuse",
-    "green_yellow",
-    "dark_green",
-    "forest_green",
-    "lime",
-    "lime_green",
-    "light_green",
-    "pale_green",
-    "dark_sea_green",
-    "medium_spring_green",
-    "spring_green",
-    "sea_green",
-    "medium_aqua_marine",
-    "medium_sea_green",
-    "light_sea_green",
-    "dark_slate_gray",
-    "teal",
-    "dark_cyan",
-    "aqua",
-    "cyan",
-    "light_cyan",
-    "dark_turquoise",
-    "turquoise",
-    "medium_turquoise",
-    "pale_turquoise",
-    "aqua_marine",
-    "powder_blue",
-    "cadet_blue",
-    "steel_blue",
-    "corn_flower_blue",
-    "deep_sky_blue",
-    "dodger_blue",
-    "light_blue",
-    "sky_blue",
-    "light_sky_blue",
-    "midnight_blue",
-    "navy",
-    "dark_blue",
-    "medium_blue",
-    "royal_blue",
-    "blue_violet",
-    "indigo",
-    "dark_slate_blue",
-    "slate_blue",
-    "medium_slate_blue",
-    "medium_purple",
-    "dark_magenta",
-    "dark_violet",
-    "dark_orchid",
-    "medium_orchid",
-    "purple",
-    "thistle",
-    "plum",
-    "violet",
-    "magenta",
-    "orchid",
-    "medium_violet_red",
-    "pale_violet_red",
-    "deep_pink",
-    "hot_pink",
-    "light_pink",
-    "pink",
-    "antique_white",
-    "beige",
-    "bisque",
-    "blanched_almond",
-    "wheat",
-    "corn_silk",
-    "lemon_chiffon",
-    "light_golden_rod_yellow",
-    "light_yellow",
-    "saddle_brown",
-    "sienna",
-    "chocolate",
-    "peru",
-    "sandy_brown",
-    "burly_wood",
-    "tan",
-    "rosy_brown",
-    "moccasin",
-    "navajo_white",
-    "peach_puff",
-    "misty_rose",
-    "lavender_blush",
-    "linen",
-    "old_lace",
-    "papaya_whip",
-    "sea_shell",
-    "mint_cream",
-    "slate_gray",
-    "light_slate_gray",
-    "light_steel_blue",
-    "lavender",
-    "floral_white",
-    "alice_blue",
-    "ghost_white",
-    "honeydew",
-    "ivory",
-    "azure",
-    "snow",
-    "silver",
-    "gainsboro",
-    "white_smoke",
-]
-
-color_id = dict([(n, k) for k, n in enumerate(color_names)])
-color_tokens = dict([(n, c) for n, c in zip(color_names, colors)])
+color_name2rgb = {
+    "white": [255, 255, 255],
+    "red": [255, 0, 0],
+    "green": [0, 128, 0],
+    "blue": [0, 0, 255],
+    "yellow": [255, 255, 0],
+    "black": [0, 0, 0],
+    "maroon": [128, 0, 0],
+    "dark_red": [139, 0, 0],
+    "brown": [165, 42, 42],
+    "firebrick": [178, 34, 34],
+    "crimson": [220, 20, 60],
+    "tomato": [255, 99, 71],
+    "coral": [255, 127, 80],
+    "indian_red": [205, 92, 92],
+    "light_coral": [240, 128, 128],
+    "dark_salmon": [233, 150, 122],
+    "salmon": [250, 128, 114],
+    "light_salmon": [255, 160, 122],
+    "orange_red": [255, 69, 0],
+    "dark_orange": [255, 140, 0],
+    "orange": [255, 165, 0],
+    "gold": [255, 215, 0],
+    "dark_golden_rod": [184, 134, 11],
+    "golden_rod": [218, 165, 32],
+    "pale_golden_rod": [238, 232, 170],
+    "dark_khaki": [189, 183, 107],
+    "khaki": [240, 230, 140],
+    "olive": [128, 128, 0],
+    "yellow_green": [154, 205, 50],
+    "dark_olive_green": [85, 107, 47],
+    "olive_drab": [107, 142, 35],
+    "lawn_green": [124, 252, 0],
+    "chartreuse": [127, 255, 0],
+    "green_yellow": [173, 255, 47],
+    "dark_green": [0, 100, 0],
+    "forest_green": [34, 139, 34],
+    "lime": [0, 255, 0],
+    "lime_green": [50, 205, 50],
+    "light_green": [144, 238, 144],
+    "pale_green": [152, 251, 152],
+    "dark_sea_green": [143, 188, 143],
+    "medium_spring_green": [0, 250, 154],
+    "spring_green": [0, 255, 127],
+    "sea_green": [46, 139, 87],
+    "medium_aqua_marine": [102, 205, 170],
+    "medium_sea_green": [60, 179, 113],
+    "light_sea_green": [32, 178, 170],
+    "dark_slate_gray": [47, 79, 79],
+    "teal": [0, 128, 128],
+    "dark_cyan": [0, 139, 139],
+    "aqua": [0, 255, 255],
+    "cyan": [0, 255, 255],
+    "light_cyan": [224, 255, 255],
+    "dark_turquoise": [0, 206, 209],
+    "turquoise": [64, 224, 208],
+    "medium_turquoise": [72, 209, 204],
+    "pale_turquoise": [175, 238, 238],
+    "aqua_marine": [127, 255, 212],
+    "powder_blue": [176, 224, 230],
+    "cadet_blue": [95, 158, 160],
+    "steel_blue": [70, 130, 180],
+    "corn_flower_blue": [100, 149, 237],
+    "deep_sky_blue": [0, 191, 255],
+    "dodger_blue": [30, 144, 255],
+    "light_blue": [173, 216, 230],
+    "sky_blue": [135, 206, 235],
+    "light_sky_blue": [135, 206, 250],
+    "midnight_blue": [25, 25, 112],
+    "navy": [0, 0, 128],
+    "dark_blue": [0, 0, 139],
+    "medium_blue": [0, 0, 205],
+    "royal_blue": [65, 105, 225],
+    "blue_violet": [138, 43, 226],
+    "indigo": [75, 0, 130],
+    "dark_slate_blue": [72, 61, 139],
+    "slate_blue": [106, 90, 205],
+    "medium_slate_blue": [123, 104, 238],
+    "medium_purple": [147, 112, 219],
+    "dark_magenta": [139, 0, 139],
+    "dark_violet": [148, 0, 211],
+    "dark_orchid": [153, 50, 204],
+    "medium_orchid": [186, 85, 211],
+    "purple": [128, 0, 128],
+    "thistle": [216, 191, 216],
+    "plum": [221, 160, 221],
+    "violet": [238, 130, 238],
+    "magenta": [255, 0, 255],
+    "orchid": [218, 112, 214],
+    "medium_violet_red": [199, 21, 133],
+    "pale_violet_red": [219, 112, 147],
+    "deep_pink": [255, 20, 147],
+    "hot_pink": [255, 105, 180],
+    "light_pink": [255, 182, 193],
+    "pink": [255, 192, 203],
+    "antique_white": [250, 235, 215],
+    "beige": [245, 245, 220],
+    "bisque": [255, 228, 196],
+    "blanched_almond": [255, 235, 205],
+    "wheat": [245, 222, 179],
+    "corn_silk": [255, 248, 220],
+    "lemon_chiffon": [255, 250, 205],
+    "light_golden_rod_yellow": [250, 250, 210],
+    "light_yellow": [255, 255, 224],
+    "saddle_brown": [139, 69, 19],
+    "sienna": [160, 82, 45],
+    "chocolate": [210, 105, 30],
+    "peru": [205, 133, 63],
+    "sandy_brown": [244, 164, 96],
+    "burly_wood": [222, 184, 135],
+    "tan": [210, 180, 140],
+    "rosy_brown": [188, 143, 143],
+    "moccasin": [255, 228, 181],
+    "navajo_white": [255, 222, 173],
+    "peach_puff": [255, 218, 185],
+    "misty_rose": [255, 228, 225],
+    "lavender_blush": [255, 240, 245],
+    "linen": [250, 240, 230],
+    "old_lace": [253, 245, 230],
+    "papaya_whip": [255, 239, 213],
+    "sea_shell": [255, 245, 238],
+    "mint_cream": [245, 255, 250],
+    "slate_gray": [112, 128, 144],
+    "light_slate_gray": [119, 136, 153],
+    "light_steel_blue": [176, 196, 222],
+    "lavender": [230, 230, 250],
+    "floral_white": [255, 250, 240],
+    "alice_blue": [240, 248, 255],
+    "ghost_white": [248, 248, 255],
+    "honeydew": [240, 255, 240],
+    "ivory": [255, 255, 240],
+    "azure": [240, 255, 255],
+    "snow": [255, 250, 250],
+    "silver": [192, 192, 192],
+    "gainsboro": [220, 220, 220],
+    "white_smoke": [245, 245, 245],
+}
+
+color_name2id = dict([(n, k) for k, n in enumerate(color_name2rgb.keys())])
+color_id2name = dict([(k, n) for k, n in enumerate(color_name2rgb.keys())])
 
 ######################################################################
 
@@ -293,7 +155,7 @@ color_tokens = dict([(n, c) for n, c in zip(color_names, colors)])
 def all_properties(height, width, nb_squares, square_i, square_j, square_c):
     s = []
 
-    for r, c_r in [(k, color_names[square_c[k]]) for k in range(nb_squares)]:
+    for r, c_r in [(k, color_id2name[square_c[k].item()]) for k in range(nb_squares)]:
         s += [f"there is {c_r}"]
 
         if square_i[r] >= height - height // 3:
@@ -305,7 +167,9 @@ def all_properties(height, width, nb_squares, square_i, square_j, square_c):
         if square_j[r] < width // 3:
             s += [f"{c_r} left"]
 
-        for t, c_t in [(k, color_names[square_c[k]]) for k in range(nb_squares)]:
+        for t, c_t in [
+            (k, color_id2name[square_c[k].item()]) for k in range(nb_squares)
+        ]:
             if square_i[r] > square_i[t]:
                 s += [f"{c_r} below {c_t}"]
             if square_i[r] < square_i[t]:
@@ -332,13 +196,11 @@ def generate(
     nb_colors=5,
     pruner=None,
 ):
-
-    assert nb_colors >= max_nb_squares and nb_colors <= len(color_tokens) - 1
+    assert nb_colors >= max_nb_squares and nb_colors <= len(color_name2rgb) - 1
 
     descr = []
 
     for n in range(nb):
-
         nb_squares = torch.randint(max_nb_squares, (1,)) + 1
         square_position = torch.randperm(height * width)[:nb_squares]
 
@@ -347,7 +209,7 @@ def generate(
         square_i = square_position.div(width, rounding_mode="floor")
         square_j = square_position % width
 
-        img = [0] * height * width
+        img = torch.zeros(height * width, dtype=torch.int64)
         for k in range(nb_squares):
             img[square_position[k]] = square_c[k]
 
@@ -364,7 +226,7 @@ def generate(
         s = (
             " <sep> ".join([s[k] for k in torch.randperm(len(s))[:nb_properties]])
             + " <img> "
-            + " ".join([f"{color_names[n]}" for n in img])
+            + " ".join([f"{color_id2name[n.item()]}" for n in img])
         )
 
         descr += [s]
@@ -377,31 +239,24 @@ def generate(
 # Extracts the image after <img> in descr as a 1x3xHxW tensor
 
 
-def descr2img(descr, n, height, width):
-
-    if type(descr) == list:
-        return torch.cat([descr2img(d, n, height, width) for d in descr], 0)
-
-    if type(n) == list:
-        return torch.cat([descr2img(descr, k, height, width) for k in n], 0).unsqueeze(
-            0
-        )
+def descr2img(descr, height, width):
+    result = []
 
     def token2color(t):
         try:
-            return color_tokens[t]
+            return color_name2rgb[t]
         except KeyError:
             return [128, 128, 128]
 
-    d = descr.split("<img>")
-    d = d[n + 1] if len(d) > n + 1 else ""
-    d = d.strip().split(" ")[: height * width]
-    d = d + ["<unk>"] * (height * width - len(d))
-    d = [token2color(t) for t in d]
-    img = torch.tensor(d).permute(1, 0)
-    img = img.reshape(1, 3, height, width)
+    for d in descr:
+        d = d.split("<img>")[1]
+        d = d.strip().split(" ")[: height * width]
+        d = d + ["<unk>"] * (height * width - len(d))
+        d = [token2color(t) for t in d]
+        img = torch.tensor(d).permute(1, 0).reshape(1, 3, height, width)
+        result.append(img)
 
-    return img
+    return torch.cat(result, 0)
 
 
 ######################################################################
@@ -410,25 +265,24 @@ def descr2img(descr, n, height, width):
 
 
 def descr2properties(descr, height, width):
-
     if type(descr) == list:
         return [descr2properties(d, height, width) for d in descr]
 
     d = descr.split("<img>")
-    d = d[-1] if len(d) > 1 else ""
-    d = d.strip().split(" ")[: height * width]
-    if len(d) != height * width:
+    img_tokens = d[-1] if len(d) > 1 else ""
+    img_tokens = img_tokens.strip().split(" ")[: height * width]
+    if len(img_tokens) != height * width:
         return []
 
     seen = {}
-    for k, x in enumerate(d):
-        if x != color_names[0]:
-            if x in color_tokens:
+    for k, x in enumerate(img_tokens):
+        if x != color_id2name[0]:
+            if x in color_name2rgb:
                 if x in seen:
                     return []
             else:
                 return []
-            seen[x] = (color_id[x], k // width, k % width)
+            seen[x] = (color_name2id[x], k // width, k % width)
 
     square_infos = tuple(zip(*seen.values()))
 
@@ -455,7 +309,6 @@ def descr2properties(descr, height, width):
 
 
 def nb_properties(descr, height, width, pruner=None):
-
     if type(descr) == list:
         return [nb_properties(d, height, width, pruner) for d in descr]
 
@@ -489,7 +342,7 @@ if __name__ == "__main__":
             for d in descr:
                 f.write(f"{d}\n\n")
 
-        img = descr2img(descr, n=0, height=12, width=16)
+        img = descr2img(descr, height=12, width=16)
         if img.size(0) == 1:
             img = F.pad(img, (1, 1, 1, 1), value=64)