Update.
[picoclvr.git] / tasks.py
index 0143ab2..5019aed 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -181,7 +181,11 @@ class SandBox(Task):
             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
         )
 
-        if save_attention_image is not None:
+        logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
+
+        if save_attention_image is None:
+            logger("no save_attention_image (is pycairo installed?)")
+        else:
             for k in range(10):
                 ns = torch.randint(self.test_input.size(0), (1,)).item()
                 input = self.test_input[ns : ns + 1].clone()
@@ -369,6 +373,10 @@ class PicoCLVR(Task):
             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
         )
 
+        logger(
+            f"main_test_accuracy {n_epoch} {1-nb_missing_properties/nb_requested_properties}"
+        )
+
     ######################################################################
 
     def produce_results(
@@ -639,6 +647,8 @@ class Maze(Task):
             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
         )
 
+        logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
+
         if count is not None:
             proportion_optimal = count.diagonal().sum().float() / count.sum()
             logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
@@ -778,6 +788,8 @@ class Snake(Task):
             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
         )
 
+        logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
+
 
 ######################################################################
 
@@ -887,6 +899,8 @@ class Stack(Task):
             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
         )
 
+        logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
+
         ##############################################################
         # Log a few generated sequences
         input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
@@ -1159,6 +1173,8 @@ class RPL(Task):
                 f"accuracy_prog_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
             )
 
+            logger(f"main_test_accuracy {n_epoch} {1-test_nb_errors/test_nb_total}")
+
         test_nb_total, test_nb_errors = compute_nb_errors_output(
             self.test_input[:1000].to(self.device), nb_to_log=10
         )
@@ -1167,7 +1183,9 @@ class RPL(Task):
             f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
         )
 
-        if save_attention_image is not None:
+        if save_attention_image is None:
+            logger("no save_attention_image (is pycairo installed?)")
+        else:
             ns = torch.randint(self.test_input.size(0), (1,)).item()
             input = self.test_input[ns : ns + 1].clone()
             last = (input != self.t_nul).max(0).values.nonzero().max() + 3
@@ -1355,6 +1373,8 @@ class Expr(Task):
             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
         )
 
+        logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
+
         nb_total = test_nb_delta.sum() + test_nb_missed
         for d in range(test_nb_delta.size(0)):
             logger(