Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 25 Aug 2023 20:33:48 +0000 (22:33 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 25 Aug 2023 20:33:48 +0000 (22:33 +0200)
grid.py
tasks.py

diff --git a/grid.py b/grid.py
index f72c8e3..5b28914 100755 (executable)
--- a/grid.py
+++ b/grid.py
@@ -19,7 +19,7 @@ name_colors = ["red", "yellow", "blue", "green", "white", "purple"]
 class GridFactory:
     def __init__(
         self,
-        size=4,
+        size=6,
         max_nb_items=4,
         max_nb_transformations=3,
         nb_questions=4,
@@ -143,14 +143,14 @@ class GridFactory:
     def generate_scene_and_questions(self):
         while True:
             while True:
-                scene = self.generate_scene()
-                true = self.all_properties(scene)
+                start_scene = self.generate_scene()
+                true = self.all_properties(start_scene)
                 if len(true) >= self.nb_questions:
                     break
 
-            start = self.grid_positions(scene)
+            start = self.grid_positions(start_scene)
 
-            scene, transformations = self.random_transformations(scene)
+            scene, transformations = self.random_transformations(start_scene)
 
             # transformations=[]
 
@@ -185,7 +185,7 @@ class GridFactory:
             + questions
         )
 
-        return scene, result
+        return start_scene, scene, result
 
     def generate_samples(self, nb, progress_bar=None):
         result = []
@@ -195,7 +195,7 @@ class GridFactory:
             r = progress_bar(r)
 
         for _ in r:
-            result.append(self.generate_scene_and_questions()[1])
+            result.append(self.generate_scene_and_questions()[2])
 
         return result
 
@@ -207,13 +207,17 @@ if __name__ == "__main__":
 
     grid_factory = GridFactory()
 
-    start_time = time.perf_counter()
-    samples = grid_factory.generate_samples(10000)
-    end_time = time.perf_counter()
-    print(f"{len(samples) / (end_time - start_time):.02f} samples per second")
+    start_time = time.perf_counter()
+    samples = grid_factory.generate_samples(10000)
+    end_time = time.perf_counter()
+    print(f"{len(samples) / (end_time - start_time):.02f} samples per second")
 
-    scene, questions = grid_factory.generate_scene_and_questions()
+    start_scene, scene, questions = grid_factory.generate_scene_and_questions()
+    print("-- Original scene -----------------------------")
+    grid_factory.print_scene(start_scene)
+    print("-- Transformed scene --------------------------")
     grid_factory.print_scene(scene)
+    print("-- Sequence -----------------------------------")
     print(questions)
 
 ######################################################################
index 2c2f914..24c13fe 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -1539,8 +1539,8 @@ class Grid(Task):
         nb_total = ar_mask.sum().item()
         nb_correct = ((correct == result).long() * ar_mask).sum().item()
 
-        logger(f"test_performance {nb_total=} {nb_correct=}")
-        logger(f"main_test_accuracy {nb_correct / nb_total}")
+        logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
+        logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
 
 
 ######################################################################