Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 23 Jun 2023 05:48:36 +0000 (07:48 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 23 Jun 2023 05:48:36 +0000 (07:48 +0200)
main.py

diff --git a/main.py b/main.py
index 0144817..784474f 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -1,4 +1,4 @@
-#!/usr/bin/env python
+!/usr/bin/env python
 
 # Any copyright is dedicated to the Public Domain.
 # https://creativecommons.org/publicdomain/zero/1.0/
@@ -622,7 +622,7 @@ class TaskMaze(Task):
     def compute_error(self, model, split="train", nb_to_use=-1):
         nb_total, nb_correct = 0, 0
         count = torch.zeros(
-            self.width * self.height, self.width * self.height, device=self.device
+            self.width * self.height, self.width * self.height, device=self.device, dtype=torch.int64
         )
         for input in task.batches(split, nb_to_use):
             result = input.clone()
@@ -676,6 +676,8 @@ class TaskMaze(Task):
             )
 
             if count is not None:
+                proportion_optimal = count.diagonal().sum().float() / count.sum()
+                log_string(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
                 with open(
                     os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
                 ) as f: