Update
[beaver.git] / maze.py
diff --git a/maze.py b/maze.py
index 6e8e179..81afcd9 100755 (executable)
--- a/maze.py
+++ b/maze.py
@@ -61,11 +61,11 @@ def create_maze(h=11, w=17, nb_walls=8):
 ######################################################################
 
 
-def compute_distance(walls, i, j):
+def compute_distance(walls, goal_i, goal_j):
     max_length = walls.numel()
     dist = torch.full_like(walls, max_length)
 
-    dist[i, j] = 0
+    dist[goal_i, goal_j] = 0
     pred_dist = torch.empty_like(dist)
 
     while True:
@@ -93,15 +93,15 @@ def compute_distance(walls, i, j):
 ######################################################################
 
 
-def compute_policy(walls, i, j):
-    distance = compute_distance(walls, i, j)
+def compute_policy(walls, goal_i, goal_j):
+    distance = compute_distance(walls, goal_i, goal_j)
     distance = distance + walls.numel() * walls
 
     value = distance.new_full((4,) + distance.size(), walls.numel())
-    value[0, :, 1:] = distance[:, :-1]
-    value[1, :, :-1] = distance[:, 1:]
-    value[2, 1:, :] = distance[:-1, :]
-    value[3, :-1, :] = distance[1:, :]
+    value[0, :, 1:] = distance[:, :-1]  # <
+    value[1, :, :-1] = distance[:, 1:]  # >
+    value[2, 1:, :] = distance[:-1, :]  # ^
+    value[3, :-1, :] = distance[1:, :]  # v
 
     proba = (value.min(dim=0)[0][None] == value).float()
     proba = proba / proba.sum(dim=0)[None]
@@ -110,6 +110,25 @@ def compute_policy(walls, i, j):
     return proba
 
 
+def stationary_densities(mazes, policies):
+    policies = policies * (mazes != v_goal)[:, None]
+    start = (mazes == v_start).nonzero(as_tuple=True)
+    probas = mazes.new_zeros(mazes.size(), dtype=torch.float32)
+    pred_probas = probas.clone()
+    probas[start] = 1.0
+
+    while not pred_probas.equal(probas):
+        pred_probas.copy_(probas)
+        probas.zero_()
+        probas[:, 1:, :] += pred_probas[:, :-1, :] * policies[:, 3, :-1, :]
+        probas[:, :-1, :] += pred_probas[:, 1:, :] * policies[:, 2, 1:, :]
+        probas[:, :, 1:] += pred_probas[:, :, :-1] * policies[:, 1, :, :-1]
+        probas[:, :, :-1] += pred_probas[:, :, 1:] * policies[:, 0, :, 1:]
+        probas[start] = 1.0
+
+    return probas
+
+
 ######################################################################
 
 
@@ -193,6 +212,7 @@ def save_image(
     target_paths=None,
     predicted_paths=None,
     score_paths=None,
+    score_truth=None,
     path_correct=None,
 ):
     colors = torch.tensor(
@@ -200,41 +220,61 @@ def save_image(
             [255, 255, 255],  # empty
             [0, 0, 0],  # wall
             [0, 255, 0],  # start
-            [0, 0, 255],  # goal
+            [127, 127, 255],  # goal
             [255, 0, 0],  # path
         ]
     )
 
     mazes = mazes.cpu()
 
-    mazes = colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2)
+    c_mazes = (
+        colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2)
+    )
 
-    imgs = mazes.unsqueeze(1)
+    if score_truth is not None:
+        score_truth = score_truth.cpu()
+        c_score_truth = score_truth.unsqueeze(1).expand(-1, 3, -1, -1)
+        c_score_truth = (
+            c_score_truth * colors[4].reshape(1, 3, 1, 1)
+            + (1 - c_score_truth) * colors[0].reshape(1, 3, 1, 1)
+        ).long()
+        c_mazes = (mazes.unsqueeze(1) != v_empty) * c_mazes + (
+            mazes.unsqueeze(1) == v_empty
+        ) * c_score_truth
+
+    imgs = c_mazes.unsqueeze(1)
 
     if target_paths is not None:
         target_paths = target_paths.cpu()
 
-        target_paths = (
+        c_target_paths = (
             colors[target_paths.reshape(-1)]
             .reshape(target_paths.size() + (-1,))
             .permute(0, 3, 1, 2)
         )
 
-        imgs = torch.cat((imgs, target_paths.unsqueeze(1)), 1)
+        imgs = torch.cat((imgs, c_target_paths.unsqueeze(1)), 1)
 
     if predicted_paths is not None:
         predicted_paths = predicted_paths.cpu()
-        predicted_paths = (
+        c_predicted_paths = (
             colors[predicted_paths.reshape(-1)]
             .reshape(predicted_paths.size() + (-1,))
             .permute(0, 3, 1, 2)
         )
-        imgs = torch.cat((imgs, predicted_paths.unsqueeze(1)), 1)
+        imgs = torch.cat((imgs, c_predicted_paths.unsqueeze(1)), 1)
 
     if score_paths is not None:
-        score_paths = (score_paths.cpu() * 255.0).long()
-        score_paths = score_paths.unsqueeze(1).expand(-1, 3, -1, -1)
-        imgs = torch.cat((imgs, score_paths.unsqueeze(1)), 1)
+        score_paths = score_paths.cpu()
+        c_score_paths = score_paths.unsqueeze(1).expand(-1, 3, -1, -1)
+        c_score_paths = (
+            c_score_paths * colors[4].reshape(1, 3, 1, 1)
+            + (1 - c_score_paths) * colors[0].reshape(1, 3, 1, 1)
+        ).long()
+        c_score_paths = c_score_paths * (mazes.unsqueeze(1) == v_empty) + c_mazes * (
+            mazes.unsqueeze(1) != v_empty
+        )
+        imgs = torch.cat((imgs, c_score_paths.unsqueeze(1)), 1)
 
     # NxKxCxHxW
     if path_correct is None: