X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=picoclvr.py;h=19517afaa05ea7c66eecdf7a1431cc6fe48e04f3;hb=ceda7771b579aa3fb21115c6e71975d3cb7583bd;hp=712da1760764b85d3639d7427e16d85a9553bb4b;hpb=046f35f38d629c9854104e855a53f0142449138f;p=mygpt.git diff --git a/picoclvr.py b/picoclvr.py index 712da17..19517af 100755 --- a/picoclvr.py +++ b/picoclvr.py @@ -67,12 +67,34 @@ color_names = [ 'azure', 'snow', 'silver', 'gainsboro', 'white_smoke', ] +color_id = dict( [ (n, k) for k, n in enumerate(color_names) ] ) color_tokens = dict( [ (n, c) for n, c in zip(color_names, colors) ] ) ###################################################################### -def generate(nb, height = 6, width = 8, - max_nb_squares = 5, max_nb_statements = 10, +def all_properties(height, width, nb_squares, square_i, square_j, square_c): + s = [ ] + + for r, c in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]: + s += [ f'there is {c}' ] + + if square_i[r] >= height - height//3: s += [ f'{c} bottom' ] + if square_i[r] < height//3: s += [ f'{c} top' ] + if square_j[r] >= width - width//3: s += [ f'{c} right' ] + if square_j[r] < width//3: s += [ f'{c} left' ] + + for t, d in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]: + if square_i[r] > square_i[t]: s += [ f'{c} below {d}' ] + if square_i[r] < square_i[t]: s += [ f'{c} above {d}' ] + if square_j[r] > square_j[t]: s += [ f'{c} right of {d}' ] + if square_j[r] < square_j[t]: s += [ f'{c} left of {d}' ] + + return s + +###################################################################### + +def generate(nb, height, width, + max_nb_squares = 5, max_nb_properties = 10, many_colors = False): nb_colors = len(color_tokens) - 1 if many_colors else max_nb_squares @@ -83,6 +105,7 @@ def generate(nb, height = 6, width = 8, nb_squares = torch.randint(max_nb_squares, (1,)) + 1 square_position = torch.randperm(height * width)[:nb_squares] + # color 0 is white and reserved for the background square_c = torch.randperm(nb_colors)[:nb_squares] + 1 square_i = square_position.div(width, rounding_mode = 'floor') square_j = square_position % width @@ -90,28 +113,14 @@ def generate(nb, height = 6, width = 8, img = [ 0 ] * height * width for k in range(nb_squares): img[square_position[k]] = square_c[k] - # generates all the true relations - - s = [ ] + # generates all the true properties - for r, c in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]: - s += [ f'there is {c}' ] + s = all_properties(height, width, nb_squares, square_i, square_j, square_c) - if square_i[r] >= height - height//3: s += [ f'{c} bottom' ] - if square_i[r] < height//3: s += [ f'{c} top' ] - if square_j[r] >= width - width//3: s += [ f'{c} right' ] - if square_j[r] < width//3: s += [ f'{c} left' ] + # pick at most max_nb_properties at random - for t, d in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]: - if square_i[r] > square_i[t]: s += [ f'{c} below {d}' ] - if square_i[r] < square_i[t]: s += [ f'{c} above {d}' ] - if square_j[r] > square_j[t]: s += [ f'{c} right of {d}' ] - if square_j[r] < square_j[t]: s += [ f'{c} left of {d}' ] - - # pick at most max_nb_statements at random - - nb_statements = torch.randint(max_nb_statements, (1,)) + 1 - s = ' '.join([ s[k] for k in torch.randperm(len(s))[:nb_statements] ] ) + nb_properties = torch.randint(max_nb_properties, (1,)) + 1 + s = ' '.join([ s[k] for k in torch.randperm(len(s))[:nb_properties] ] ) s += ' ' + ' '.join([ f'{color_names[n]}' for n in img ]) descr += [ s ] @@ -120,7 +129,10 @@ def generate(nb, height = 6, width = 8, ###################################################################### -def descr2img(descr, height = 6, width = 8): +def descr2img(descr, height, width): + + if type(descr) == list: + return torch.cat([ descr2img(d, height, width) for d in descr ], 0) def token2color(t): try: @@ -128,38 +140,82 @@ def descr2img(descr, height = 6, width = 8): except KeyError: return [ 128, 128, 128 ] - def img_descr(x): - u = x.split('', 1) - return u[1] if len(u) > 1 else '' - - img = torch.full((len(descr), 3, height, width), 255) - d = [ img_descr(x) for x in descr ] - d = [ u.strip().split(' ')[:height * width] for u in d ] - d = [ u + [ '' ] * (height * width - len(u)) for u in d ] - d = [ [ token2color(t) for t in u ] for u in d ] - img = torch.tensor(d).permute(0, 2, 1) - img = img.reshape(img.size(0), 3, height, width) + d = descr.split('', 1) + d = d[-1] if len(d) > 1 else '' + d = d.strip().split(' ')[:height * width] + d = d + [ '' ] * (height * width - len(d)) + d = [ token2color(t) for t in d ] + img = torch.tensor(d).permute(1, 0) + img = img.reshape(1, 3, height, width) return img ###################################################################### +def descr2properties(descr, height, width): + + if type(descr) == list: + return [ descr2properties(d, height, width) for d in descr ] + + d = descr.split('', 1) + d = d[-1] if len(d) > 1 else '' + d = d.strip().split(' ')[:height * width] + + seen = {} + if len(d) != height * width: return [] + for k, x in enumerate(d): + if x != color_names[0]: + if x in color_tokens: + if x in seen: return [] + else: + return [] + seen[x] = (color_id[x], k // width, k % width) + + square_c = torch.tensor( [ x[0] for x in seen.values() ] ) + square_i = torch.tensor( [ x[1] for x in seen.values() ] ) + square_j = torch.tensor( [ x[2] for x in seen.values() ] ) + + s = all_properties(height, width, len(seen), square_i, square_j, square_c) + + return s + +###################################################################### + +def nb_missing_properties(descr, height, width): + if type(descr) == list: + return [ nb_missing_properties(d, height, width) for d in descr ] + + d = descr.split('', 1) + if len(d) == 0: return 0 + d = d[0].strip().split('') + d = [ x.strip() for x in d ] + + requested_properties = set(d) + all_properties = set(descr2properties(descr, height, width)) + missing_properties = requested_properties - all_properties + + return (len(requested_properties), len(all_properties), len(missing_properties)) + +###################################################################### + if __name__ == '__main__': descr = generate(nb = 5) - for d in descr: - print(d) - print() - img = descr2img(descr) - print(img.size()) + #print(descr2properties(descr)) + print(nb_missing_properties(descr)) + + with open('picoclvr_example.txt', 'w') as f: + for d in descr: + f.write(f'{d}\n\n') + img = descr2img(descr) torchvision.utils.save_image(img / 255., 'picoclvr_example.png', nrow = 16, pad_value = 0.8) import time start_time = time.perf_counter() - descr = generate(10000) + descr = generate(nb = 1000) end_time = time.perf_counter() print(f'{len(descr) / (end_time - start_time):.02f} samples per second')