X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=picoclvr.py;h=f097eb027601b82525e02b2622c378740e633c52;hb=e68f19634d3282e39a488d146480b19bb23e8652;hp=774ae3b6af3bd45da09713fa249ab611ece8e580;hpb=b5efc396f45c23b7de0fe11f618731ac2b900d99;p=mygpt.git
diff --git a/picoclvr.py b/picoclvr.py
index 774ae3b..f097eb0 100755
--- a/picoclvr.py
+++ b/picoclvr.py
@@ -67,6 +67,7 @@ color_names = [
'azure', 'snow', 'silver', 'gainsboro', 'white_smoke',
]
+color_id = dict( [ (n, k) for k, n in enumerate(color_names) ] )
color_tokens = dict( [ (n, c) for n, c in zip(color_names, colors) ] )
######################################################################
@@ -92,11 +93,11 @@ def all_properties(height, width, nb_squares, square_i, square_j, square_c):
######################################################################
-def generate(nb, height = 6, width = 8,
+def generate(nb, height, width,
max_nb_squares = 5, max_nb_properties = 10,
- many_colors = False):
+ nb_colors = 5):
- nb_colors = len(color_tokens) - 1 if many_colors else max_nb_squares
+ assert nb_colors >= max_nb_squares and nb_colors <= len(color_tokens) - 1
descr = [ ]
@@ -128,10 +129,10 @@ def generate(nb, height = 6, width = 8,
######################################################################
-def descr2img(descr, height = 6, width = 8):
+def descr2img(descr, height, width):
if type(descr) == list:
- return torch.cat([ descr2img(d) for d in descr ], 0)
+ return torch.cat([ descr2img(d, height, width) for d in descr ], 0)
def token2color(t):
try:
@@ -151,9 +152,60 @@ def descr2img(descr, height = 6, width = 8):
######################################################################
+def descr2properties(descr, height, width):
+
+ if type(descr) == list:
+ return [ descr2properties(d, height, width) for d in descr ]
+
+ d = descr.split('', 1)
+ d = d[-1] if len(d) > 1 else ''
+ d = d.strip().split(' ')[:height * width]
+
+ seen = {}
+ if len(d) != height * width: return []
+
+ for k, x in enumerate(d):
+ if x != color_names[0]:
+ if x in color_tokens:
+ if x in seen: return []
+ else:
+ return []
+ seen[x] = (color_id[x], k // width, k % width)
+
+ square_infos = zip(*seen.values())
+ square_c = torch.tensor(square_infos[0])
+ square_i = torch.tensor(square_infos[1])
+ square_j = torch.tensor(square_infos[2])
+
+ s = all_properties(height, width, len(seen), square_i, square_j, square_c)
+
+ return s
+
+######################################################################
+
+def nb_properties(descr, height, width):
+ if type(descr) == list:
+ return [ nb_properties(d, height, width) for d in descr ]
+
+ d = descr.split('', 1)
+ if len(d) == 0: return 0
+ d = d[0].strip().split('')
+ d = [ x.strip() for x in d ]
+
+ requested_properties = set(d)
+ all_properties = set(descr2properties(descr, height, width))
+ missing_properties = requested_properties - all_properties
+
+ return (len(requested_properties), len(all_properties), len(missing_properties))
+
+######################################################################
+
if __name__ == '__main__':
descr = generate(nb = 5)
+ #print(descr2properties(descr))
+ print(nb_properties(descr))
+
with open('picoclvr_example.txt', 'w') as f:
for d in descr:
f.write(f'{d}\n\n')
@@ -165,7 +217,7 @@ if __name__ == '__main__':
import time
start_time = time.perf_counter()
- descr = generate(10000)
+ descr = generate(nb = 1000)
end_time = time.perf_counter()
print(f'{len(descr) / (end_time - start_time):.02f} samples per second')