Update.
[picoclvr.git] / tasks.py
index cc3aea0..5019aed 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -181,6 +181,8 @@ class SandBox(Task):
             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
         )
 
+        logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
+
         if save_attention_image is None:
             logger("no save_attention_image (is pycairo installed?)")
         else:
@@ -371,6 +373,10 @@ class PicoCLVR(Task):
             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
         )
 
+        logger(
+            f"main_test_accuracy {n_epoch} {1-nb_missing_properties/nb_requested_properties}"
+        )
+
     ######################################################################
 
     def produce_results(
@@ -641,6 +647,8 @@ class Maze(Task):
             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
         )
 
+        logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
+
         if count is not None:
             proportion_optimal = count.diagonal().sum().float() / count.sum()
             logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
@@ -780,6 +788,8 @@ class Snake(Task):
             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
         )
 
+        logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
+
 
 ######################################################################
 
@@ -889,6 +899,8 @@ class Stack(Task):
             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
         )
 
+        logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
+
         ##############################################################
         # Log a few generated sequences
         input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
@@ -1161,6 +1173,8 @@ class RPL(Task):
                 f"accuracy_prog_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
             )
 
+            logger(f"main_test_accuracy {n_epoch} {1-test_nb_errors/test_nb_total}")
+
         test_nb_total, test_nb_errors = compute_nb_errors_output(
             self.test_input[:1000].to(self.device), nb_to_log=10
         )
@@ -1359,6 +1373,8 @@ class Expr(Task):
             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
         )
 
+        logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
+
         nb_total = test_nb_delta.sum() + test_nb_missed
         for d in range(test_nb_delta.size(0)):
             logger(