- nb = torch.randint(5, (1,)) + 1
- shape_position = torch.randperm(height * width)[:nb]
- shape_c = torch.randperm(5)[:nb] + 1
+ nb_shapes = torch.randint(len(color_tokens) - 1, (1,)) + 1
+ shape_position = torch.randperm(height * width)[:nb_shapes]
+ shape_c = torch.randperm(len(color_tokens) - 1)[:nb_shapes] + 1