projects
/
mygptrnn.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[mygptrnn.git]
/
maze.py
diff --git
a/maze.py
b/maze.py
index
8ac9fce
..
4953d10
100755
(executable)
--- a/
maze.py
+++ b/
maze.py
@@
-231,9
+231,14
@@
def save_image(
[0, 255, 0], # start
[127, 127, 255], # goal
[255, 0, 0], # path
[0, 255, 0], # start
[127, 127, 255], # goal
[255, 0, 0], # path
+ [128, 128, 128], # error
]
)
]
)
+ def safe_colors(x):
+ m = (x >= 0).long() * (x < colors.size(0) - 1).long()
+ return colors[x * m + (colors.size(0) - 1) * (1 - m)]
+
mazes = mazes.cpu()
c_mazes = (
mazes = mazes.cpu()
c_mazes = (
@@
-256,7
+261,7
@@
def save_image(
if predicted_paths is not None:
predicted_paths = predicted_paths.cpu()
c_predicted_paths = (
if predicted_paths is not None:
predicted_paths = predicted_paths.cpu()
c_predicted_paths = (
- colors[predicted_paths.reshape(-1)]
+ safe_colors(predicted_paths.reshape(-1))
.reshape(predicted_paths.size() + (-1,))
.permute(0, 3, 1, 2)
)
.reshape(predicted_paths.size() + (-1,))
.permute(0, 3, 1, 2)
)
@@
-282,8
+287,6
@@
def save_image(
-1, -1, imgs.size(3) + 2, 1 + imgs.size(1) * (1 + imgs.size(4))
).clone()
-1, -1, imgs.size(3) + 2, 1 + imgs.size(1) * (1 + imgs.size(4))
).clone()
- print(f"{img.size()=} {imgs.size()=}")
-
for k in range(imgs.size(1)):
img[
:,
for k in range(imgs.size(1)):
img[
:,