projects
/
pytorch.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
7cf92d1
)
Update.
author
François Fleuret
<francois@fleuret.org>
Tue, 31 Oct 2023 16:55:45 +0000
(17:55 +0100)
committer
François Fleuret
<francois@fleuret.org>
Tue, 31 Oct 2023 16:55:45 +0000
(17:55 +0100)
picocrafter.py
patch
|
blob
|
history
diff --git
a/picocrafter.py
b/picocrafter.py
index
31ba1e4
..
7810b67
100755
(executable)
--- a/
picocrafter.py
+++ b/
picocrafter.py
@@
-79,7
+79,7
@@
class PicroCrafterEngine:
self.reward_per_hit = -1
self.reward_death = -10
self.reward_per_hit = -1
self.reward_death = -10
- self.tokens = " +#@$aAbBcC"
+ self.tokens = " +#@$aAbBcC
.
"
self.token2id = dict([(t, n) for n, t in enumerate(self.tokens)])
self.id2token = dict([(n, t) for n, t in enumerate(self.tokens)])
self.token2id = dict([(t, n) for n, t in enumerate(self.tokens)])
self.id2token = dict([(n, t) for n, t in enumerate(self.tokens)])
@@
-92,7
+92,7
@@
class PicroCrafterEngine:
("b", "B"),
("B", "c"),
("c", "C"),
("b", "B"),
("B", "c"),
("c", "C"),
- ("C", "
"),
+ ("C", "
.
"),
]
]
)
]
]
)
@@
-111,7
+111,7
@@
class PicroCrafterEngine:
]
)
]
)
- self.acessible_object_to_inventory = dict(
+ self.ac
c
essible_object_to_inventory = dict(
[
(self.token2id[s], self.token2id[t])
for (s, t) in [
[
(self.token2id[s], self.token2id[t])
for (s, t) in [
@@
-120,7
+120,8
@@
class PicroCrafterEngine:
("b", " "),
("B", "b"),
("c", " "),
("b", " "),
("B", "b"),
("c", " "),
- ("C", " "),
+ ("C", "c"),
+ (".", " "),
]
]
)
]
]
)
@@
-208,7
+209,6
@@
class PicroCrafterEngine:
s = torch.tensor([[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]], device=self.device)
b = a.clone()
b[:, 1:] = b[:, 1:] + s[actions[b[:, 0]]]
s = torch.tensor([[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]], device=self.device)
b = a.clone()
b[:, 1:] = b[:, 1:] + s[actions[b[:, 0]]]
-
# position is empty
o = (self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.token2id[" "]).long()
# or it is the next accessible object
# position is empty
o = (self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.token2id[" "]).long()
# or it is the next accessible object
@@
-219,6
+219,10
@@
class PicroCrafterEngine:
b = (1 - o) * a + o * b
self.worlds[b[:, 0], b[:, 1], b[:, 2]] = self.token2id["@"]
b = (1 - o) * a + o * b
self.worlds[b[:, 0], b[:, 1], b[:, 2]] = self.token2id["@"]
+ qq = q
+ q = qq.new_zeros((self.worlds.size(0),) + qq.size()[1:])
+ q[b[:, 0]] = qq
+
nb_hits = self.monster_moves()
alive_before = self.life_level_in_100th > 0
nb_hits = self.monster_moves()
alive_before = self.life_level_in_100th > 0
@@
-243,14
+247,14
@@
class PicroCrafterEngine:
)
inventory = torch.tensor(
[
)
inventory = torch.tensor(
[
- self.acessible_object_to_inventory[s.item()]
+ self.ac
c
essible_object_to_inventory[s.item()]
for s in self.accessible_object
]
)
self.life_level_in_100th = (
self.life_level_in_100th
for s in self.accessible_object
]
)
self.life_level_in_100th = (
self.life_level_in_100th
- * (self.accessible_object != self.token2id["
"]).long()
+ * (self.accessible_object != self.token2id["
.
"]).long()
)
reward[torch.logical_not(alive_before)] = 0
)
reward[torch.logical_not(alive_before)] = 0
@@
-392,16
+396,19
@@
if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- ansi_term = False
- # nb_agents, nb_iter, display = 1000, 100, False
+
#
ansi_term = False
+ # nb_agents, nb_iter, display = 1000, 100
0
, False
nb_agents, nb_iter, display = 3, 10000, True
nb_agents, nb_iter, display = 3, 10000, True
-
#
ansi_term = True
+ ansi_term = True
start_time = time.perf_counter()
engine = PicroCrafterEngine(
world_height=27,
world_width=27,
nb_walls=35,
start_time = time.perf_counter()
engine = PicroCrafterEngine(
world_height=27,
world_width=27,
nb_walls=35,
+ # world_height=15,
+ # world_width=15,
+ # nb_walls=0,
view_height=9,
view_width=9,
margin=4,
view_height=9,
view_width=9,
margin=4,
@@
-414,6
+421,7
@@
if __name__ == "__main__":
start_time = time.perf_counter()
start_time = time.perf_counter()
+ stop = 0
for k in range(nb_iter):
action = torch.randint(engine.nb_actions(), (nb_agents,), device=device)
rewards, inventories, life_levels = engine.step(
for k in range(nb_iter):
action = torch.randint(engine.nb_actions(), (nb_agents,), device=device)
rewards, inventories, life_levels = engine.step(
@@
-438,7
+446,9
@@
if __name__ == "__main__":
time.sleep(0.25)
if (life_levels > 0).long().sum() == 0:
time.sleep(0.25)
if (life_levels > 0).long().sum() == 0:
- break
+ stop += 1
+ if stop == 2:
+ break
print(
f"timing {(nb_agents*nb_iter)/(time.perf_counter() - start_time)} iteration per s"
print(
f"timing {(nb_agents*nb_iter)/(time.perf_counter() - start_time)} iteration per s"