self.height = height
self.width = width
self.max_nb_items = max_nb_items
+ self.max_nb_transformations = max_nb_transformations
self.nb_questions = nb_questions
def generate_scene(self):
self.height, self.width
)
- def random_transformations(self):
+ def random_transformations(self, scene):  # scene: (col, shp) pair -> ((col, shp), list of "<chg> ..." strings)
+ col, shp = scene
+ descriptions = []  # one "<chg> ..." entry per applied step, in application order
nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item()
+ transformations = torch.randint(5, (nb_transformations,))  # one code in 0..4 per step (randint's upper bound is exclusive)
+ # Steps compose: every branch rebinds col/shp, so a later step acts on the result of the earlier ones.
+ for t in transformations:
+ if t == 0:  # code 0: reverse dim 0
+ col, shp = col.flip(0), shp.flip(0)
+ descriptions += ["<chg> vertical flip"]
+ elif t == 1:  # code 1: reverse dim 1
+ col, shp = col.flip(1), shp.flip(1)
+ descriptions += ["<chg> horizontal flip"]
+ elif t == 2:  # code 2: reverse dim 0, then transpose
+ col, shp = col.flip(0).t(), shp.flip(0).t()  # NOTE(review): t() swaps the two dims, so a non-square grid changes shape here - confirm callers cope
+ descriptions += ["<chg> rotate 90 degrees"]
+ elif t == 3:  # code 3: reverse both dims
+ col, shp = col.flip(0).flip(1), shp.flip(0).flip(1)
+ descriptions += ["<chg> rotate 180 degrees"]
+ elif t == 4:  # code 4: reverse dim 1, then transpose
+ col, shp = col.flip(1).t(), shp.flip(1).t()
+ descriptions += ["<chg> rotate 270 degrees"]
+ # t() hands back a transposed view; contiguous() is presumably here so later .view(-1) calls on the scene work - confirm.
+ return (col.contiguous(), shp.contiguous()), descriptions
def print_scene(self, scene):
col, shp = scene
start = self.grid_positions(scene)
+ scene, transformations = self.random_transformations(scene)
+
for a in range(10):
col, shp = scene
col, shp = col.view(-1), shp.view(-1)
questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]]
result = " ".join(
- ["<obj> " + x for x in self.grid_positions(scene)] + questions
+ ["<obj> " + x for x in self.grid_positions(scene)]
+ + transformations
+ + questions
)
return scene, result