Added default configurations and reformated with black.
[mygpt.git] / picoclvr.py
index 059e352..fb791fe 100755 (executable)
 
 import torch, torchvision
 
-colors = [
-    [ 255, 255, 255 ], [ 255, 0, 0 ], [ 0, 128, 0 ], [ 0, 0, 255 ], [ 255, 255, 0 ],
-    [ 0, 0, 0 ], [ 128, 0, 0 ], [ 139, 0, 0 ], [ 165, 42, 42 ], [ 178, 34, 34 ],
-    [ 220, 20, 60 ], [ 255, 99, 71 ], [ 255, 127, 80 ], [ 205, 92, 92 ], [ 240, 128, 128 ],
-    [ 233, 150, 122 ], [ 250, 128, 114 ], [ 255, 160, 122 ], [ 255, 69, 0 ], [ 255, 140, 0 ],
-    [ 255, 165, 0 ], [ 255, 215, 0 ], [ 184, 134, 11 ], [ 218, 165, 32 ], [ 238, 232, 170 ],
-    [ 189, 183, 107 ], [ 240, 230, 140 ], [ 128, 128, 0 ], [ 154, 205, 50 ], [ 85, 107, 47 ],
-    [ 107, 142, 35 ], [ 124, 252, 0 ], [ 127, 255, 0 ], [ 173, 255, 47 ], [ 0, 100, 0 ],
-    [ 34, 139, 34 ], [ 0, 255, 0 ], [ 50, 205, 50 ], [ 144, 238, 144 ], [ 152, 251, 152 ],
-    [ 143, 188, 143 ], [ 0, 250, 154 ], [ 0, 255, 127 ], [ 46, 139, 87 ], [ 102, 205, 170 ],
-    [ 60, 179, 113 ], [ 32, 178, 170 ], [ 47, 79, 79 ], [ 0, 128, 128 ], [ 0, 139, 139 ],
-    [ 0, 255, 255 ], [ 0, 255, 255 ], [ 224, 255, 255 ], [ 0, 206, 209 ], [ 64, 224, 208 ],
-    [ 72, 209, 204 ], [ 175, 238, 238 ], [ 127, 255, 212 ], [ 176, 224, 230 ], [ 95, 158, 160 ],
-    [ 70, 130, 180 ], [ 100, 149, 237 ], [ 0, 191, 255 ], [ 30, 144, 255 ], [ 173, 216, 230 ],
-    [ 135, 206, 235 ], [ 135, 206, 250 ], [ 25, 25, 112 ], [ 0, 0, 128 ], [ 0, 0, 139 ],
-    [ 0, 0, 205 ], [ 65, 105, 225 ], [ 138, 43, 226 ], [ 75, 0, 130 ], [ 72, 61, 139 ],
-    [ 106, 90, 205 ], [ 123, 104, 238 ], [ 147, 112, 219 ], [ 139, 0, 139 ], [ 148, 0, 211 ],
-    [ 153, 50, 204 ], [ 186, 85, 211 ], [ 128, 0, 128 ], [ 216, 191, 216 ], [ 221, 160, 221 ],
-    [ 238, 130, 238 ], [ 255, 0, 255 ], [ 218, 112, 214 ], [ 199, 21, 133 ], [ 219, 112, 147 ],
-    [ 255, 20, 147 ], [ 255, 105, 180 ], [ 255, 182, 193 ], [ 255, 192, 203 ], [ 250, 235, 215 ],
-    [ 245, 245, 220 ], [ 255, 228, 196 ], [ 255, 235, 205 ], [ 245, 222, 179 ], [ 255, 248, 220 ],
-    [ 255, 250, 205 ], [ 250, 250, 210 ], [ 255, 255, 224 ], [ 139, 69, 19 ], [ 160, 82, 45 ],
-    [ 210, 105, 30 ], [ 205, 133, 63 ], [ 244, 164, 96 ], [ 222, 184, 135 ], [ 210, 180, 140 ],
-    [ 188, 143, 143 ], [ 255, 228, 181 ], [ 255, 222, 173 ], [ 255, 218, 185 ], [ 255, 228, 225 ],
-    [ 255, 240, 245 ], [ 250, 240, 230 ], [ 253, 245, 230 ], [ 255, 239, 213 ], [ 255, 245, 238 ],
-    [ 245, 255, 250 ], [ 112, 128, 144 ], [ 119, 136, 153 ], [ 176, 196, 222 ], [ 230, 230, 250 ],
-    [ 255, 250, 240 ], [ 240, 248, 255 ], [ 248, 248, 255 ], [ 240, 255, 240 ], [ 255, 255, 240 ],
-    [ 240, 255, 255 ], [ 255, 250, 250 ], [ 192, 192, 192 ], [ 220, 220, 220 ], [ 245, 245, 245 ],
-]
-
-color_names = [
-    'white', 'red', 'green', 'blue', 'yellow',
-    'black', 'maroon', 'dark_red', 'brown', 'firebrick',
-    'crimson', 'tomato', 'coral', 'indian_red', 'light_coral',
-    'dark_salmon', 'salmon', 'light_salmon', 'orange_red', 'dark_orange',
-    'orange', 'gold', 'dark_golden_rod', 'golden_rod', 'pale_golden_rod',
-    'dark_khaki', 'khaki', 'olive', 'yellow_green', 'dark_olive_green',
-    'olive_drab', 'lawn_green', 'chartreuse', 'green_yellow', 'dark_green',
-    'forest_green', 'lime', 'lime_green', 'light_green', 'pale_green',
-    'dark_sea_green', 'medium_spring_green', 'spring_green', 'sea_green', 'medium_aqua_marine',
-    'medium_sea_green', 'light_sea_green', 'dark_slate_gray', 'teal', 'dark_cyan',
-    'aqua', 'cyan', 'light_cyan', 'dark_turquoise', 'turquoise',
-    'medium_turquoise', 'pale_turquoise', 'aqua_marine', 'powder_blue', 'cadet_blue',
-    'steel_blue', 'corn_flower_blue', 'deep_sky_blue', 'dodger_blue', 'light_blue',
-    'sky_blue', 'light_sky_blue', 'midnight_blue', 'navy', 'dark_blue',
-    'medium_blue', 'royal_blue', 'blue_violet', 'indigo', 'dark_slate_blue',
-    'slate_blue', 'medium_slate_blue', 'medium_purple', 'dark_magenta', 'dark_violet',
-    'dark_orchid', 'medium_orchid', 'purple', 'thistle', 'plum',
-    'violet', 'magenta', 'orchid', 'medium_violet_red', 'pale_violet_red',
-    'deep_pink', 'hot_pink', 'light_pink', 'pink', 'antique_white',
-    'beige', 'bisque', 'blanched_almond', 'wheat', 'corn_silk',
-    'lemon_chiffon', 'light_golden_rod_yellow', 'light_yellow', 'saddle_brown', 'sienna',
-    'chocolate', 'peru', 'sandy_brown', 'burly_wood', 'tan',
-    'rosy_brown', 'moccasin', 'navajo_white', 'peach_puff', 'misty_rose',
-    'lavender_blush', 'linen', 'old_lace', 'papaya_whip', 'sea_shell',
-    'mint_cream', 'slate_gray', 'light_slate_gray', 'light_steel_blue', 'lavender',
-    'floral_white', 'alice_blue', 'ghost_white', 'honeydew', 'ivory',
-    'azure', 'snow', 'silver', 'gainsboro', 'white_smoke',
-]
-
-color_id = dict( [ (n, k) for k, n in enumerate(color_names) ] )
-color_tokens = dict( [ (n, c) for n, c in zip(color_names, colors) ] )
+color_tokens = {
+    "white": [255, 255, 255],
+    "red": [255, 0, 0],
+    "green": [0, 128, 0],
+    "blue": [0, 0, 255],
+    "yellow": [255, 255, 0],
+    "black": [0, 0, 0],
+    "maroon": [128, 0, 0],
+    "dark_red": [139, 0, 0],
+    "brown": [165, 42, 42],
+    "firebrick": [178, 34, 34],
+    "crimson": [220, 20, 60],
+    "tomato": [255, 99, 71],
+    "coral": [255, 127, 80],
+    "indian_red": [205, 92, 92],
+    "light_coral": [240, 128, 128],
+    "dark_salmon": [233, 150, 122],
+    "salmon": [250, 128, 114],
+    "light_salmon": [255, 160, 122],
+    "orange_red": [255, 69, 0],
+    "dark_orange": [255, 140, 0],
+    "orange": [255, 165, 0],
+    "gold": [255, 215, 0],
+    "dark_golden_rod": [184, 134, 11],
+    "golden_rod": [218, 165, 32],
+    "pale_golden_rod": [238, 232, 170],
+    "dark_khaki": [189, 183, 107],
+    "khaki": [240, 230, 140],
+    "olive": [128, 128, 0],
+    "yellow_green": [154, 205, 50],
+    "dark_olive_green": [85, 107, 47],
+    "olive_drab": [107, 142, 35],
+    "lawn_green": [124, 252, 0],
+    "chartreuse": [127, 255, 0],
+    "green_yellow": [173, 255, 47],
+    "dark_green": [0, 100, 0],
+    "forest_green": [34, 139, 34],
+    "lime": [0, 255, 0],
+    "lime_green": [50, 205, 50],
+    "light_green": [144, 238, 144],
+    "pale_green": [152, 251, 152],
+    "dark_sea_green": [143, 188, 143],
+    "medium_spring_green": [0, 250, 154],
+    "spring_green": [0, 255, 127],
+    "sea_green": [46, 139, 87],
+    "medium_aqua_marine": [102, 205, 170],
+    "medium_sea_green": [60, 179, 113],
+    "light_sea_green": [32, 178, 170],
+    "dark_slate_gray": [47, 79, 79],
+    "teal": [0, 128, 128],
+    "dark_cyan": [0, 139, 139],
+    "aqua": [0, 255, 255],
+    "cyan": [0, 255, 255],
+    "light_cyan": [224, 255, 255],
+    "dark_turquoise": [0, 206, 209],
+    "turquoise": [64, 224, 208],
+    "medium_turquoise": [72, 209, 204],
+    "pale_turquoise": [175, 238, 238],
+    "aqua_marine": [127, 255, 212],
+    "powder_blue": [176, 224, 230],
+    "cadet_blue": [95, 158, 160],
+    "steel_blue": [70, 130, 180],
+    "corn_flower_blue": [100, 149, 237],
+    "deep_sky_blue": [0, 191, 255],
+    "dodger_blue": [30, 144, 255],
+    "light_blue": [173, 216, 230],
+    "sky_blue": [135, 206, 235],
+    "light_sky_blue": [135, 206, 250],
+    "midnight_blue": [25, 25, 112],
+    "navy": [0, 0, 128],
+    "dark_blue": [0, 0, 139],
+    "medium_blue": [0, 0, 205],
+    "royal_blue": [65, 105, 225],
+    "blue_violet": [138, 43, 226],
+    "indigo": [75, 0, 130],
+    "dark_slate_blue": [72, 61, 139],
+    "slate_blue": [106, 90, 205],
+    "medium_slate_blue": [123, 104, 238],
+    "medium_purple": [147, 112, 219],
+    "dark_magenta": [139, 0, 139],
+    "dark_violet": [148, 0, 211],
+    "dark_orchid": [153, 50, 204],
+    "medium_orchid": [186, 85, 211],
+    "purple": [128, 0, 128],
+    "thistle": [216, 191, 216],
+    "plum": [221, 160, 221],
+    "violet": [238, 130, 238],
+    "magenta": [255, 0, 255],
+    "orchid": [218, 112, 214],
+    "medium_violet_red": [199, 21, 133],
+    "pale_violet_red": [219, 112, 147],
+    "deep_pink": [255, 20, 147],
+    "hot_pink": [255, 105, 180],
+    "light_pink": [255, 182, 193],
+    "pink": [255, 192, 203],
+    "antique_white": [250, 235, 215],
+    "beige": [245, 245, 220],
+    "bisque": [255, 228, 196],
+    "blanched_almond": [255, 235, 205],
+    "wheat": [245, 222, 179],
+    "corn_silk": [255, 248, 220],
+    "lemon_chiffon": [255, 250, 205],
+    "light_golden_rod_yellow": [250, 250, 210],
+    "light_yellow": [255, 255, 224],
+    "saddle_brown": [139, 69, 19],
+    "sienna": [160, 82, 45],
+    "chocolate": [210, 105, 30],
+    "peru": [205, 133, 63],
+    "sandy_brown": [244, 164, 96],
+    "burly_wood": [222, 184, 135],
+    "tan": [210, 180, 140],
+    "rosy_brown": [188, 143, 143],
+    "moccasin": [255, 228, 181],
+    "navajo_white": [255, 222, 173],
+    "peach_puff": [255, 218, 185],
+    "misty_rose": [255, 228, 225],
+    "lavender_blush": [255, 240, 245],
+    "linen": [250, 240, 230],
+    "old_lace": [253, 245, 230],
+    "papaya_whip": [255, 239, 213],
+    "sea_shell": [255, 245, 238],
+    "mint_cream": [245, 255, 250],
+    "slate_gray": [112, 128, 144],
+    "light_slate_gray": [119, 136, 153],
+    "light_steel_blue": [176, 196, 222],
+    "lavender": [230, 230, 250],
+    "floral_white": [255, 250, 240],
+    "alice_blue": [240, 248, 255],
+    "ghost_white": [248, 248, 255],
+    "honeydew": [240, 255, 240],
+    "ivory": [255, 255, 240],
+    "azure": [240, 255, 255],
+    "snow": [255, 250, 250],
+    "silver": [192, 192, 192],
+    "gainsboro": [220, 220, 220],
+    "white_smoke": [245, 245, 245],
+}
+
+color_id = dict([(n, k) for k, n in enumerate(color_tokens.keys())])
+color_names = dict([(k, n) for k, n in enumerate(color_tokens.keys())])
 
 ######################################################################
 
-def all_properties(height, width, nb_squares, square_i, square_j, square_c):
-    s = [ ]
-
-    for r, c in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]:
-        s += [ f'there is {c}' ]
 
-        if square_i[r] >= height - height//3: s += [ f'{c} bottom' ]
-        if square_i[r] < height//3: s += [ f'{c} top' ]
-        if square_j[r] >= width - width//3: s += [ f'{c} right' ]
-        if square_j[r] < width//3: s += [ f'{c} left' ]
-
-        for t, d in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]:
-            if square_i[r] > square_i[t]: s += [ f'{c} below {d}' ]
-            if square_i[r] < square_i[t]: s += [ f'{c} above {d}' ]
-            if square_j[r] > square_j[t]: s += [ f'{c} right of {d}' ]
-            if square_j[r] < square_j[t]: s += [ f'{c} left of {d}' ]
+def all_properties(height, width, nb_squares, square_i, square_j, square_c):
+    s = []
+
+    for r, c in [(k, color_names[square_c[k].item()]) for k in range(nb_squares)]:
+        s += [f"there is {c}"]
+
+        if square_i[r] >= height - height // 3:
+            s += [f"{c} bottom"]
+        if square_i[r] < height // 3:
+            s += [f"{c} top"]
+        if square_j[r] >= width - width // 3:
+            s += [f"{c} right"]
+        if square_j[r] < width // 3:
+            s += [f"{c} left"]
+
+        for t, d in [(k, color_names[square_c[k].item()]) for k in range(nb_squares)]:
+            if square_i[r] > square_i[t]:
+                s += [f"{c} below {d}"]
+            if square_i[r] < square_i[t]:
+                s += [f"{c} above {d}"]
+            if square_j[r] > square_j[t]:
+                s += [f"{c} right of {d}"]
+            if square_j[r] < square_j[t]:
+                s += [f"{c} left of {d}"]
 
     return s
 
+
 ######################################################################
 
-def generate(nb, height, width,
-             max_nb_squares = 5, max_nb_properties = 10,
-             nb_colors = 5,
-             pruning_criterion = None):
+
+def generate(
+    nb,
+    height,
+    width,
+    max_nb_squares=5,
+    max_nb_properties=10,
+    nb_colors=5,
+    pruning_criterion=None,
+):
 
     assert nb_colors >= max_nb_squares and nb_colors <= len(color_tokens) - 1
 
-    descr = [ ]
+    descr = []
 
     for n in range(nb):
 
@@ -108,70 +202,77 @@ def generate(nb, height, width,
         square_position = torch.randperm(height * width)[:nb_squares]
         # color 0 is white and reserved for the background
         square_c = torch.randperm(nb_colors)[:nb_squares] + 1
-        square_i = square_position.div(width, rounding_mode = 'floor')
+        square_i = square_position.div(width, rounding_mode="floor")
         square_j = square_position % width
 
-        img = [ 0 ] * height * width
-        for k in range(nb_squares): img[square_position[k]] = square_c[k]
+        img = torch.zeros(height * width, dtype=torch.int64)
+        for k in range(nb_squares):
+            img[square_position[k]] = square_c[k]
 
         # generates all the true properties
 
         s = all_properties(height, width, nb_squares, square_i, square_j, square_c)
 
         if pruning_criterion is not None:
-            s = list(filter(pruning_criterion,s))
+            s = list(filter(pruning_criterion, s))
 
         # pick at most max_nb_properties at random
 
         nb_properties = torch.randint(max_nb_properties, (1,)) + 1
-        s = ' <sep> '.join([ s[k] for k in torch.randperm(len(s))[:nb_properties] ] )
-        s += ' <img> ' + ' '.join([ f'{color_names[n]}' for n in img ])
+        s = " <sep> ".join([s[k] for k in torch.randperm(len(s))[:nb_properties]])
+        s += " <img> " + " ".join([f"{color_names[n.item()]}" for n in img])
 
-        descr += [ s ]
+        descr += [s]
 
     return descr
 
+
 ######################################################################
 
+
 def descr2img(descr, height, width):
 
     if type(descr) == list:
-        return torch.cat([ descr2img(d, height, width) for d in descr ], 0)
+        return torch.cat([descr2img(d, height, width) for d in descr], 0)
 
     def token2color(t):
         try:
             return color_tokens[t]
         except KeyError:
-            return [ 128, 128, 128 ]
+            return [128, 128, 128]
 
-    d = descr.split('<img>', 1)
-    d = d[-1] if len(d) > 1 else ''
-    d = d.strip().split(' ')[:height * width]
-    d = d + [ '<unk>' ] * (height * width - len(d))
-    d = [ token2color(t) for t in d ]
+    d = descr.split("<img>", 1)
+    d = d[-1] if len(d) > 1 else ""
+    d = d.strip().split(" ")[: height * width]
+    d = d + ["<unk>"] * (height * width - len(d))
+    d = [token2color(t) for t in d]
     img = torch.tensor(d).permute(1, 0)
     img = img.reshape(1, 3, height, width)
 
     return img
 
+
 ######################################################################
 
+
 def descr2properties(descr, height, width):
 
     if type(descr) == list:
-        return [ descr2properties(d, height, width) for d in descr ]
+        return [descr2properties(d, height, width) for d in descr]
 
-    d = descr.split('<img>', 1)
-    d = d[-1] if len(d) > 1 else ''
-    d = d.strip().split(' ')[:height * width]
+    d = descr.split("<img>", 1)
+    d = d[-1] if len(d) > 1 else ""
+    d = d.strip().split(" ")[: height * width]
 
     seen = {}
-    if len(d) != height * width: return []
+    if len(d) != height * width:
+        return []
 
     for k, x in enumerate(d):
         if x != color_names[0]:
             if x in color_tokens:
-                if x in seen: return []
+                if x in seen:
+                    return []
             else:
                 return []
             seen[x] = (color_id[x], k // width, k % width)
@@ -190,16 +291,19 @@ def descr2properties(descr, height, width):
 
     return s
 
+
 ######################################################################
 
+
 def nb_properties(descr, height, width):
     if type(descr) == list:
-        return [ nb_properties(d, height, width) for d in descr ]
+        return [nb_properties(d, height, width) for d in descr]
 
-    d = descr.split('<img>', 1)
-    if len(d) == 0: return 0
-    d = d[0].strip().split('<sep>')
-    d = [ x.strip() for x in d ]
+    d = descr.split("<img>", 1)
+    if len(d) == 0:
+        return 0
+    d = d[0].strip().split("<sep>")
+    d = [x.strip() for x in d]
 
     requested_properties = set(d)
     all_properties = set(descr2properties(descr, height, width))
@@ -207,30 +311,36 @@ def nb_properties(descr, height, width):
 
     return (len(requested_properties), len(all_properties), len(missing_properties))
 
+
 ######################################################################
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     descr = generate(
-        nb = 5, height = 12, width = 16,
-        pruning_criterion = lambda s: not ('green' in s and ('right' in s or 'left' in s))
+        nb=5,
+        height=12,
+        width=16,
+        pruning_criterion=lambda s: not (
+            "green" in s and ("right" in s or "left" in s)
+        ),
     )
 
-    print(descr2properties(descr, height = 12, width = 16))
-    print(nb_properties(descr, height = 12, width = 16))
+    print(descr2properties(descr, height=12, width=16))
+    print(nb_properties(descr, height=12, width=16))
 
-    with open('picoclvr_example.txt', 'w') as f:
+    with open("picoclvr_example.txt", "w") as f:
         for d in descr:
-            f.write(f'{d}\n\n')
+            f.write(f"{d}\n\n")
 
-    img = descr2img(descr, height = 12, width = 16)
-    torchvision.utils.save_image(img / 255.,
-                                 'picoclvr_example.png', nrow = 16, pad_value = 0.8)
+    img = descr2img(descr, height=12, width=16)
+    torchvision.utils.save_image(
+        img / 255.0, "picoclvr_example.png", nrow=16, pad_value=0.8
+    )
 
     import time
 
     start_time = time.perf_counter()
-    descr = generate(nb = 1000, height = 12, width = 16)
+    descr = generate(nb=1000, height=12, width=16)
     end_time = time.perf_counter()
-    print(f'{len(descr) / (end_time - start_time):.02f} samples per second')
+    print(f"{len(descr) / (end_time - start_time):.02f} samples per second")
 
 ######################################################################