X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=picoclvr.py;fp=picoclvr.py;h=f097eb027601b82525e02b2622c378740e633c52;hb=e68f19634d3282e39a488d146480b19bb23e8652;hp=c4550723965cb85f6b786a44e5f11854a859af7d;hpb=52c6bd98650c846459f10e8303dd2e6c7ba2a68f;p=mygpt.git diff --git a/picoclvr.py b/picoclvr.py index c455072..f097eb0 100755 --- a/picoclvr.py +++ b/picoclvr.py @@ -163,6 +163,7 @@ def descr2properties(descr, height, width): seen = {} if len(d) != height * width: return [] + for k, x in enumerate(d): if x != color_names[0]: if x in color_tokens: @@ -171,9 +172,10 @@ def descr2properties(descr, height, width): return [] seen[x] = (color_id[x], k // width, k % width) - square_c = torch.tensor( [ x[0] for x in seen.values() ] ) - square_i = torch.tensor( [ x[1] for x in seen.values() ] ) - square_j = torch.tensor( [ x[2] for x in seen.values() ] ) + square_infos = zip(*seen.values()) + square_c = torch.tensor(square_infos[0]) + square_i = torch.tensor(square_infos[1]) + square_j = torch.tensor(square_infos[2]) s = all_properties(height, width, len(seen), square_i, square_j, square_c)