Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 25 Aug 2023 16:58:43 +0000 (18:58 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 25 Aug 2023 16:58:43 +0000 (18:58 +0200)
grid.py
tasks.py

diff --git a/grid.py b/grid.py
index 70f7739..433cfd5 100755 (executable)
--- a/grid.py
+++ b/grid.py
@@ -28,6 +28,7 @@ class GridFactory:
         self.height = height
         self.width = width
         self.max_nb_items = max_nb_items
+        self.max_nb_transformations = max_nb_transformations
         self.nb_questions = nb_questions
 
     def generate_scene(self):
@@ -44,8 +45,30 @@ class GridFactory:
             self.height, self.width
         )
 
-    def random_transformations(self):
+    def random_transformations(self, scene):
+        col, shp = scene
+        descriptions = []
         nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item()
+        transformations = torch.randint(5, (nb_transformations,))
+
+        for t in transformations:
+            if t == 0:
+                col, shp = col.flip(0), shp.flip(0)
+                descriptions += ["<chg> vertical flip"]
+            elif t == 1:
+                col, shp = col.flip(1), shp.flip(1)
+                descriptions += ["<chg> horizontal flip"]
+            elif t == 2:
+                col, shp = col.flip(0).t(), shp.flip(0).t()
+                descriptions += ["<chg> rotate 90 degrees"]
+            elif t == 3:
+                col, shp = col.flip(0).flip(1), shp.flip(0).flip(1)
+                descriptions += ["<chg> rotate 180 degrees"]
+            elif t == 4:
+                col, shp = col.flip(1).t(), shp.flip(1).t()
+                descriptions += ["<chg> rotate 270 degrees"]
+
+        return (col.contiguous(), shp.contiguous()), descriptions
 
     def print_scene(self, scene):
         col, shp = scene
@@ -128,6 +151,8 @@ class GridFactory:
 
             start = self.grid_positions(scene)
 
+            scene, transformations = self.random_transformations(scene)
+
             for a in range(10):
                 col, shp = scene
                 col, shp = col.view(-1), shp.view(-1)
@@ -156,7 +181,9 @@ class GridFactory:
         questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]]
 
         result = " ".join(
-            ["<obj> " + x for x in self.grid_positions(scene)] + questions
+            ["<obj> " + x for x in self.grid_positions(scene)]
+            + transformations
+            + questions
         )
 
         return scene, result
index c7348d5..0ab1823 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -1429,7 +1429,7 @@ class Grid(Task):
     def tensorize(self, descr):
         token_descr = [s.strip().split(" ") for s in descr]
         l = max([len(s) for s in token_descr])
-        token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
+        token_descr = [s + ["#"] * (l - len(s)) for s in token_descr]
         id_descr = [[self.token2id[u] for u in s] for s in token_descr]
         return torch.tensor(id_descr, device=self.device)
 
@@ -1440,7 +1440,7 @@ class Grid(Task):
     # trim all the tensors in the tuple z to remove as much token from
     # left and right in the first tensor. If z is a tuple, all its
     # elements are trimed according to the triming for the first
-    def trim(self, z, token="<nul>"):
+    def trim(self, z, token="#"):
         n = self.token2id[token]
         if type(z) == tuple:
             x = z[0]
@@ -1483,7 +1483,7 @@ class Grid(Task):
         )
 
         # Build the tokenizer
-        tokens = {}
+        tokens = set()
         for d in [self.train_descr, self.test_descr]:
             for s in d:
                 for t in s.strip().split(" "):
@@ -1492,10 +1492,10 @@ class Grid(Task):
         # the same descr
         tokens = list(tokens)
         tokens.sort()
-        tokens = ["<nul>"] + tokens
+        tokens = ["#"] + tokens
         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
-        self.t_nul = self.token2id["<nul>"]
+        self.t_nul = self.token2id["#"]
         self.t_true = self.token2id["<true>"]
         self.t_false = self.token2id["<false>"]