projects
/
pytorch.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
aaa1fc7
)
Update.
author
François Fleuret
<francois@fleuret.org>
Mon, 6 Nov 2023 07:06:33 +0000
(08:06 +0100)
committer
François Fleuret
<francois@fleuret.org>
Mon, 6 Nov 2023 07:06:33 +0000
(08:06 +0100)
picocrafter.py
patch
|
blob
|
history
diff --git
a/picocrafter.py
b/picocrafter.py
index
e303554
..
36088ac
100755
(executable)
--- a/
picocrafter.py
+++ b/
picocrafter.py
@@
-96,19
+96,19
@@
class PicroCrafterEngine:
world_height=27,
world_width=27,
nb_walls=27,
world_height=27,
world_width=27,
nb_walls=27,
- margin=2,
+
world_
margin=2,
view_height=5,
view_width=5,
device=torch.device("cpu"),
):
view_height=5,
view_width=5,
device=torch.device("cpu"),
):
- assert (world_height - 2 *
margin) % (view_height - 2 *
margin) == 0
- assert (world_width - 2 *
margin) % (view_width - 2 *
margin) == 0
+ assert (world_height - 2 *
world_margin) % (view_height - 2 * world_
margin) == 0
+ assert (world_width - 2 *
world_margin) % (view_width - 2 * world_
margin) == 0
self.device = device
self.world_height = world_height
self.world_width = world_width
self.device = device
self.world_height = world_height
self.world_width = world_width
- self.
margin =
margin
+ self.
world_margin = world_
margin
self.view_height = view_height
self.view_width = view_width
self.nb_walls = nb_walls
self.view_height = view_height
self.view_width = view_width
self.nb_walls = nb_walls
@@
-168,7
+168,11
@@
class PicroCrafterEngine:
def reset(self, nb_agents):
self.worlds = self.create_worlds(
def reset(self, nb_agents):
self.worlds = self.create_worlds(
- nb_agents, self.world_height, self.world_width, self.nb_walls, self.margin
+ nb_agents,
+ self.world_height,
+ self.world_width,
+ self.nb_walls,
+ self.world_margin,
).to(self.device)
self.life_level_in_100th = torch.full(
(nb_agents,), self.life_level_max * 100 + 99, device=self.device
).to(self.device)
self.life_level_in_100th = torch.full(
(nb_agents,), self.life_level_max * 100 + 99, device=self.device
@@
-209,9
+213,11
@@
class PicroCrafterEngine:
return m
return m
- def create_worlds(self, nb, height, width, nb_walls, margin=2):
- margin -= 1 # The maze adds a wall all around
- m = self.create_mazes(nb, height - 2 * margin, width - 2 * margin, nb_walls)
+ def create_worlds(self, nb, height, width, nb_walls, world_margin=2):
+ world_margin -= 1 # The maze adds a wall all around
+ m = self.create_mazes(
+ nb, height - 2 * world_margin, width - 2 * world_margin, nb_walls
+ )
q = m.flatten(1)
z = "@aAbBcC$$$$$" # What to add to the maze
u = torch.rand(q.size(), device=q.device) * (1 - q)
q = m.flatten(1)
z = "@aAbBcC$$$$$" # What to add to the maze
u = torch.rand(q.size(), device=q.device) * (1 - q)
@@
-222,12
+228,12
@@
class PicroCrafterEngine:
torch.arange(q.size(0), device=q.device)[:, None].expand_as(r), r
] = torch.tensor([self.tile2id[c] for c in z], device=q.device)[None, :]
torch.arange(q.size(0), device=q.device)[:, None].expand_as(r), r
] = torch.tensor([self.tile2id[c] for c in z], device=q.device)[None, :]
- if margin > 0:
+ if
world_
margin > 0:
r = m.new_full(
r = m.new_full(
- (m.size(0), m.size(1) +
margin * 2, m.size(2) +
margin * 2),
+ (m.size(0), m.size(1) +
world_margin * 2, m.size(2) + world_
margin * 2),
self.tile2id["+"],
)
self.tile2id["+"],
)
- r[:,
margin:-margin, margin:-
margin] = m
+ r[:,
world_margin:-world_margin, world_margin:-world_
margin] = m
m = r
return m
m = r
return m
@@
-378,12
+384,12
@@
class PicroCrafterEngine:
def views(self):
i_height, i_width = (
def views(self):
i_height, i_width = (
- self.view_height - 2 * self.margin,
- self.view_width - 2 * self.margin,
+ self.view_height - 2 * self.
world_
margin,
+ self.view_width - 2 * self.
world_
margin,
)
a = (self.worlds == self.tile2id["@"]).nonzero()
)
a = (self.worlds == self.tile2id["@"]).nonzero()
- y = i_height * ((a[:, 1] - self.margin) // i_height)
- x = i_width * ((a[:, 2] - self.margin) // i_width)
+ y = i_height * ((a[:, 1] - self.
world_
margin) // i_height)
+ x = i_width * ((a[:, 2] - self.
world_
margin) // i_width)
n = a[:, 0][:, None, None].expand(-1, self.view_height, self.view_width)
i = (
torch.arange(self.view_height, device=a.device)[None, :, None]
n = a[:, 0][:, None, None].expand(-1, self.view_height, self.view_width)
i = (
torch.arange(self.view_height, device=a.device)[None, :, None]
@@
-414,7
+420,7
@@
class PicroCrafterEngine:
return v
return v
- def seq2tile
pic(self, t, width
):
+ def seq2tile
s(self, t, width=None
):
def tile(n):
n = n.item()
if n in self.id2tile:
def tile(n):
n = n.item()
if n in self.id2tile:
@@
-423,7
+429,10
@@
class PicroCrafterEngine:
return "?"
if t.dim() == 2:
return "?"
if t.dim() == 2:
- return [self.seq2tilepic(r, width) for r in t]
+ return [self.seq2tiles(r, width) for r in t]
+
+ if width is None:
+ width = self.view_width
t = t.reshape(-1, width)
t = t.reshape(-1, width)
@@
-458,7
+467,7
@@
if __name__ == "__main__":
nb_walls=35,
view_height=9,
view_width=9,
nb_walls=35,
view_height=9,
view_width=9,
- margin=4,
+
world_
margin=4,
device=device,
)
device=device,
)
@@
-478,7
+487,7
@@
if __name__ == "__main__":
to_print = ""
os.system("clear")
to_print = ""
os.system("clear")
- l = engine.seq2tile
pic
(engine.worlds.flatten(1), width=engine.world_width)
+ l = engine.seq2tile
s
(engine.worlds.flatten(1), width=engine.world_width)
to_print += char_conv(fusion_multi_lines(l)) + "\n\n"
to_print += char_conv(fusion_multi_lines(l)) + "\n\n"
@@
-488,7
+497,7
@@
if __name__ == "__main__":
rewards, inventories, life_levels = engine.step(action)
if display:
rewards, inventories, life_levels = engine.step(action)
if display:
- l = engine.seq2tile
pic(views.flatten(1), engine.view_width
)
+ l = engine.seq2tile
s(views.flatten(1)
)
l = [
v + [f"{engine.action2str(a.item())}/{r: 3d}"]
for (v, a, r) in zip(l, action, rewards)
l = [
v + [f"{engine.action2str(a.item())}/{r: 3d}"]
for (v, a, r) in zip(l, action, rewards)