Update.
[picoclvr.git] / tasks.py
index 0143ab2..cc3aea0 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -181,7 +181,9 @@ class SandBox(Task):
             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
         )
 
-        if save_attention_image is not None:
+        if save_attention_image is None:
+            logger("no save_attention_image (is pycairo installed?)")
+        else:
             for k in range(10):
                 ns = torch.randint(self.test_input.size(0), (1,)).item()
                 input = self.test_input[ns : ns + 1].clone()
@@ -1167,7 +1169,9 @@ class RPL(Task):
             f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
         )
 
-        if save_attention_image is not None:
+        if save_attention_image is None:
+            logger("no save_attention_image (is pycairo installed?)")
+        else:
             ns = torch.randint(self.test_input.size(0), (1,)).item()
             input = self.test_input[ns : ns + 1].clone()
             last = (input != self.t_nul).max(0).values.nonzero().max() + 3