Update.
[mygptrnn.git] / maze.py
diff --git a/maze.py b/maze.py
index 8ac9fce..4953d10 100755 (executable)
--- a/maze.py
+++ b/maze.py
@@ -231,9 +231,14 @@ def save_image(
             [0, 255, 0],  # start
             [127, 127, 255],  # goal
             [255, 0, 0],  # path
+            [128, 128, 128],  # error
         ]
     )
 
+    def safe_colors(x):
+        m = (x >= 0).long() * (x < colors.size(0) - 1).long()
+        return colors[x * m + (colors.size(0) - 1) * (1 - m)]
+
     mazes = mazes.cpu()
 
     c_mazes = (
@@ -256,7 +261,7 @@ def save_image(
     if predicted_paths is not None:
         predicted_paths = predicted_paths.cpu()
         c_predicted_paths = (
-            colors[predicted_paths.reshape(-1)]
+            safe_colors(predicted_paths.reshape(-1))
             .reshape(predicted_paths.size() + (-1,))
             .permute(0, 3, 1, 2)
         )
@@ -282,8 +287,6 @@ def save_image(
         -1, -1, imgs.size(3) + 2, 1 + imgs.size(1) * (1 + imgs.size(4))
     ).clone()
 
-    print(f"{img.size()=} {imgs.size()=}")
-
     for k in range(imgs.size(1)):
         img[
             :,