From: Francois Fleuret Date: Sat, 16 Jul 2022 08:14:21 +0000 (+0200) Subject: Update. X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=mygpt.git;a=commitdiff_plain;h=5bc2d741ea7aac83005f099665b47f8a090931cb Update. --- diff --git a/picoclvr.py b/picoclvr.py index 774ae3b..437439e 100755 --- a/picoclvr.py +++ b/picoclvr.py @@ -67,6 +67,7 @@ color_names = [ 'azure', 'snow', 'silver', 'gainsboro', 'white_smoke', ] +color_id = dict( [ (n, k) for k, n in enumerate(color_names) ] ) color_tokens = dict( [ (n, c) for n, c in zip(color_names, colors) ] ) ###################################################################### @@ -131,7 +132,7 @@ def generate(nb, height = 6, width = 8, def descr2img(descr, height = 6, width = 8): if type(descr) == list: - return torch.cat([ descr2img(d) for d in descr ], 0) + return torch.cat([ descr2img(d, height, width) for d in descr ], 0) def token2color(t): try: @@ -151,8 +152,38 @@ def descr2img(descr, height = 6, width = 8): ###################################################################### +def descr2properties(descr, height = 6, width = 8): + + if type(descr) == list: + return [ descr2properties(d, height, width) for d in descr ] + + d = descr.split('', 1) + d = d[-1] if len(d) > 1 else '' + d = d.strip().split(' ')[:height * width] + + seen = {} + if len(d) != height * width: return [] + for k, x in enumerate(d): + if x != color_names[0]: + if x in color_tokens: + if x in seen: return [] + else: + return [] + seen[x] = (color_id[x], k // width, k % width) + + square_c = torch.tensor( [ x[0] for x in seen.values() ] ) + square_i = torch.tensor( [ x[1] for x in seen.values() ] ) + square_j = torch.tensor( [ x[2] for x in seen.values() ] ) + + s = all_properties(height, width, len(seen), square_i, square_j, square_c) + + return s + +###################################################################### + if __name__ == '__main__': descr = generate(nb = 5) + print(descr2properties(descr)) with open('picoclvr_example.txt', 'w') as f: for d in descr: