def generate(nb, height, width,
max_nb_squares = 5, max_nb_properties = 10,
- many_colors = False):
+ nb_colors = 5,
+ pruning_criterion = None):
- nb_colors = len(color_tokens) - 1 if many_colors else max_nb_squares
+ assert nb_colors >= max_nb_squares and nb_colors <= len(color_tokens) - 1
descr = [ ]
s = all_properties(height, width, nb_squares, square_i, square_j, square_c)
+ if pruning_criterion is not None:
+ s = list(filter(pruning_criterion,s))
+
# pick at most max_nb_properties at random
nb_properties = torch.randint(max_nb_properties, (1,)) + 1
seen = {}
if len(d) != height * width: return []
+
for k, x in enumerate(d):
if x != color_names[0]:
if x in color_tokens:
return []
seen[x] = (color_id[x], k // width, k % width)
- square_c = torch.tensor( [ x[0] for x in seen.values() ] )
- square_i = torch.tensor( [ x[1] for x in seen.values() ] )
- square_j = torch.tensor( [ x[2] for x in seen.values() ] )
+ square_infos = tuple(zip(*seen.values()))
+ if square_infos:
+ square_c = torch.tensor(square_infos[0])
+ square_i = torch.tensor(square_infos[1])
+ square_j = torch.tensor(square_infos[2])
+ else:
+ square_c = torch.tensor([])
+ square_i = torch.tensor([])
+ square_j = torch.tensor([])
s = all_properties(height, width, len(seen), square_i, square_j, square_c)
######################################################################
-def nb_missing_properties(descr, height, width):
+def nb_properties(descr, height, width):
if type(descr) == list:
- return [ nb_missing_properties(d, height, width) for d in descr ]
+ return [ nb_properties(d, height, width) for d in descr ]
d = descr.split('<img>', 1)
if len(d) == 0: return 0
######################################################################
if __name__ == '__main__':
- descr = generate(nb = 5)
+ descr = generate(
+ nb = 5, height = 12, width = 16,
+ pruning_criterion = lambda s: not ('green' in s and ('right' in s or 'left' in s))
+ )
- #print(descr2properties(descr))
- print(nb_missing_properties(descr))
+ print(descr2properties(descr, height = 12, width = 16))
+ print(nb_properties(descr, height = 12, width = 16))
with open('picoclvr_example.txt', 'w') as f:
for d in descr:
f.write(f'{d}\n\n')
- img = descr2img(descr)
+ img = descr2img(descr, height = 12, width = 16)
torchvision.utils.save_image(img / 255.,
'picoclvr_example.png', nrow = 16, pad_value = 0.8)
import time
start_time = time.perf_counter()
- descr = generate(nb = 1000)
+ descr = generate(nb = 1000, height = 12, width = 16)
end_time = time.perf_counter()
print(f'{len(descr) / (end_time - start_time):.02f} samples per second')