Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 25 Aug 2023 22:47:09 +0000 (00:47 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 25 Aug 2023 22:47:09 +0000 (00:47 +0200)
grid.py
tasks.py

diff --git a/grid.py b/grid.py
index 5b28914..60baedf 100755 (executable)
--- a/grid.py
+++ b/grid.py
@@ -24,6 +24,7 @@ class GridFactory:
         max_nb_transformations=3,
         nb_questions=4,
     ):
+        assert size % 2 == 0
         self.size = size
         self.max_nb_items = max_nb_items
         self.max_nb_transformations = max_nb_transformations
@@ -137,6 +138,8 @@ class GridFactory:
                                     properties += [f"a {n1} is right of a {n2}"]
                                 if j1 < j2:
                                     properties += [f"a {n1} is left of a {n2}"]
+                                if abs(i1 - i2) + abs(j1 - j2) == 1:
+                                    properties += [f"a {n1} is next to a {n2}"]
 
         return properties
 
@@ -144,16 +147,11 @@ class GridFactory:
         while True:
             while True:
                 start_scene = self.generate_scene()
-                true = self.all_properties(start_scene)
+                scene, transformations = self.random_transformations(start_scene)
+                true = self.all_properties(scene)
                 if len(true) >= self.nb_questions:
                     break
 
-            start = self.grid_positions(start_scene)
-
-            scene, transformations = self.random_transformations(start_scene)
-
-            # transformations=[]
-
             for a in range(10):
                 col, shp = scene
                 col, shp = col.view(-1), shp.view(-1)
@@ -163,8 +161,17 @@ class GridFactory:
                     col.view(self.size, self.size),
                     shp.view(self.size, self.size),
                 )
-                # other_scene = self.generate_scene()
-                false = list(set(self.all_properties(other_scene)) - set(true))
+
+                false = self.all_properties(other_scene)
+
+                # We sometime add properties from a totally different
+                # scene to have negative "there is a xxx xxx"
+                # properties
+                if torch.rand(1).item() < 0.2:
+                    other_scene = self.generate_scene()
+                    false += self.all_properties(other_scene)
+
+                false = list(set(false) - set(true))
                 if len(false) >= self.nb_questions:
                     break
 
@@ -173,14 +180,14 @@ class GridFactory:
 
         true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions]]
         false = [false[k] for k in torch.randperm(len(false))[: self.nb_questions]]
-        true = ["<prop> " + q + " <true>" for q in true]
-        false = ["<prop> " + q + " <false>" for q in false]
+        true = ["<prop> " + q + " <ans> true" for q in true]
+        false = ["<prop> " + q + " <ans> false" for q in false]
 
         union = true + false
         questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]]
 
         result = " ".join(
-            ["<obj> " + x for x in self.grid_positions(scene)]
+            ["<obj> " + x for x in self.grid_positions(start_scene)]
             + transformations
             + questions
         )
index 24c13fe..cbc8e6b 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -1495,8 +1495,8 @@ class Grid(Task):
         self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
         self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
         self.t_nul = self.token2id["#"]
-        self.t_true = self.token2id["<true>"]
-        self.t_false = self.token2id["<false>"]
+        self.t_true = self.token2id["true"]
+        self.t_false = self.token2id["false"]
 
         # Tokenize the train and test sets
         self.train_input = self.tensorize(self.train_descr)