Update.
[picoclvr.git] / grid.py
diff --git a/grid.py b/grid.py
index 70f7739..433cfd5 100755 (executable)
--- a/grid.py
+++ b/grid.py
@@ -28,6 +28,7 @@ class GridFactory:
         self.height = height
         self.width = width
         self.max_nb_items = max_nb_items
+        self.max_nb_transformations = max_nb_transformations
         self.nb_questions = nb_questions
 
     def generate_scene(self):
@@ -44,8 +45,30 @@ class GridFactory:
             self.height, self.width
         )
 
-    def random_transformations(self):
+    def random_transformations(self, scene):
+        col, shp = scene
+        descriptions = []
         nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item()
+        transformations = torch.randint(5, (nb_transformations,))
+
+        for t in transformations:
+            if t == 0:
+                col, shp = col.flip(0), shp.flip(0)
+                descriptions += ["<chg> vertical flip"]
+            elif t == 1:
+                col, shp = col.flip(1), shp.flip(1)
+                descriptions += ["<chg> horizontal flip"]
+            elif t == 2:
+                col, shp = col.flip(0).t(), shp.flip(0).t()
+                descriptions += ["<chg> rotate 90 degrees"]
+            elif t == 3:
+                col, shp = col.flip(0).flip(1), shp.flip(0).flip(1)
+                descriptions += ["<chg> rotate 180 degrees"]
+            elif t == 4:
+                col, shp = col.flip(1).t(), shp.flip(1).t()
+                descriptions += ["<chg> rotate 270 degrees"]
+
+        return (col.contiguous(), shp.contiguous()), descriptions
 
     def print_scene(self, scene):
         col, shp = scene
@@ -128,6 +151,8 @@ class GridFactory:
 
             start = self.grid_positions(scene)
 
+            scene, transformations = self.random_transformations(scene)
+
             for a in range(10):
                 col, shp = scene
                 col, shp = col.view(-1), shp.view(-1)
@@ -156,7 +181,9 @@ class GridFactory:
         questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]]
 
         result = " ".join(
-            ["<obj> " + x for x in self.grid_positions(scene)] + questions
+            ["<obj> " + x for x in self.grid_positions(scene)]
+            + transformations
+            + questions
         )
 
         return scene, result