Fixed a bug when there are no squares.
[mygpt.git] / picoclvr.py
index 437439e..3ecbf3a 100755 (executable)
@@ -93,11 +93,11 @@ def all_properties(height, width, nb_squares, square_i, square_j, square_c):
 
 ######################################################################
 
-def generate(nb, height = 6, width = 8,
+def generate(nb, height, width,
              max_nb_squares = 5, max_nb_properties = 10,
-             many_colors = False):
+             nb_colors = 5):
 
-    nb_colors =  len(color_tokens) - 1 if many_colors else max_nb_squares
+    assert nb_colors >= max_nb_squares and nb_colors <= len(color_tokens) - 1
 
     descr = [ ]
 
@@ -129,7 +129,7 @@ def generate(nb, height = 6, width = 8,
 
 ######################################################################
 
-def descr2img(descr, height = 6, width = 8):
+def descr2img(descr, height, width):
 
     if type(descr) == list:
         return torch.cat([ descr2img(d, height, width) for d in descr ], 0)
@@ -152,7 +152,7 @@ def descr2img(descr, height = 6, width = 8):
 
 ######################################################################
 
-def descr2properties(descr, height = 6, width = 8):
+def descr2properties(descr, height, width):
 
     if type(descr) == list:
         return [ descr2properties(d, height, width) for d in descr ]
@@ -163,6 +163,7 @@ def descr2properties(descr, height = 6, width = 8):
 
     seen = {}
     if len(d) != height * width: return []
+
     for k, x in enumerate(d):
         if x != color_names[0]:
             if x in color_tokens:
@@ -171,9 +172,15 @@ def descr2properties(descr, height = 6, width = 8):
                 return []
             seen[x] = (color_id[x], k // width, k % width)
 
-    square_c = torch.tensor( [ x[0] for x in seen.values() ] )
-    square_i = torch.tensor( [ x[1] for x in seen.values() ] )
-    square_j = torch.tensor( [ x[2] for x in seen.values() ] )
+    square_infos = tuple(zip(*seen.values()))
+    if square_infos:
+        square_c = torch.tensor(square_infos[0])
+        square_i = torch.tensor(square_infos[1])
+        square_j = torch.tensor(square_infos[2])
+    else:
+        square_c = torch.tensor([])
+        square_i = torch.tensor([])
+        square_j = torch.tensor([])
 
     s = all_properties(height, width, len(seen), square_i, square_j, square_c)
 
@@ -181,9 +188,28 @@ def descr2properties(descr, height = 6, width = 8):
 
 ######################################################################
 
+def nb_properties(descr, height, width):
+    if type(descr) == list:
+        return [ nb_properties(d, height, width) for d in descr ]
+
+    d = descr.split('<img>', 1)
+    if len(d) == 0: return 0
+    d = d[0].strip().split('<sep>')
+    d = [ x.strip() for x in d ]
+
+    requested_properties = set(d)
+    all_properties = set(descr2properties(descr, height, width))
+    missing_properties = requested_properties - all_properties
+
+    return (len(requested_properties), len(all_properties), len(missing_properties))
+
+######################################################################
+
 if __name__ == '__main__':
     descr = generate(nb = 5)
-    print(descr2properties(descr))
+
+    #print(descr2properties(descr))
+    print(nb_properties(descr))
 
     with open('picoclvr_example.txt', 'w') as f:
         for d in descr:
@@ -196,7 +222,7 @@ if __name__ == '__main__':
     import time
 
     start_time = time.perf_counter()
-    descr = generate(10000)
+    descr = generate(nb = 1000)
     end_time = time.perf_counter()
     print(f'{len(descr) / (end_time - start_time):.02f} samples per second')