projects
/
picoclvr.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
c03e968
)
Update.
author
François Fleuret
<francois@fleuret.org>
Tue, 11 Jul 2023 06:13:35 +0000
(08:13 +0200)
committer
François Fleuret
<francois@fleuret.org>
Tue, 11 Jul 2023 06:13:35 +0000
(08:13 +0200)
world.py
patch
|
blob
|
history
diff --git
a/world.py
b/world.py
index
a43eff9
..
d32d545
100755
(executable)
--- a/
world.py
+++ b/
world.py
@@
-1,6
+1,6
@@
#!/usr/bin/env python
#!/usr/bin/env python
-import math, sys
+import math, sys
, tqdm
import torch, torchvision
import torch, torchvision
@@
-232,10
+232,11
@@
class Normalizer(nn.Module):
def __init__(self, mu, std):
super().__init__()
self.mu = nn.Parameter(mu)
def __init__(self, mu, std):
super().__init__()
self.mu = nn.Parameter(mu)
- self.log_var = nn.Parameter(2
*
torch.log(std))
+ self.log_var = nn.Parameter(2
*
torch.log(std))
def forward(self, x):
def forward(self, x):
- return (x-self.mu)/torch.exp(self.log_var/2.0)
+ return (x - self.mu) / torch.exp(self.log_var / 2.0)
+
class SignSTE(nn.Module):
def __init__(self):
class SignSTE(nn.Module):
def __init__(self):
@@
-256,8
+257,9
@@
def train_encoder(
dim_hidden=64,
block_size=16,
nb_bits_per_block=10,
dim_hidden=64,
block_size=16,
nb_bits_per_block=10,
- lr_start=1e-3, lr_end=1e-5,
- nb_epochs=50,
+ lr_start=1e-3,
+ lr_end=1e-5,
+ nb_epochs=10,
batch_size=25,
device=torch.device("cpu"),
):
batch_size=25,
device=torch.device("cpu"),
):
@@
-312,12
+314,19
@@
def train_encoder(
model.to(device)
for k in range(nb_epochs):
model.to(device)
for k in range(nb_epochs):
- lr=math.exp(math.log(lr_start) + math.log(lr_end/lr_start)/(nb_epochs-1)*k)
+ lr = math.exp(
+ math.log(lr_start) + math.log(lr_end / lr_start) / (nb_epochs - 1) * k
+ )
print(f"lr {lr}")
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
acc_loss, nb_samples = 0.0, 0
print(f"lr {lr}")
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
acc_loss, nb_samples = 0.0, 0
- for input in train_input.split(batch_size):
+ for input in tqdm.tqdm(
+ train_input.split(batch_size),
+ dynamic_ncols=True,
+ desc="vqae-train",
+ total=train_input.size(0) // batch_size,
+ ):
output = model(input)
loss = F.mse_loss(output, input)
acc_loss += loss.item() * input.size(0)
output = model(input)
loss = F.mse_loss(output, input)
acc_loss += loss.item() * input.size(0)
@@
-341,7
+350,11
@@
if __name__ == "__main__":
all_frames = []
nb = 25000
start_time = time.perf_counter()
all_frames = []
nb = 25000
start_time = time.perf_counter()
- for n in range(nb):
+ for n in tqdm.tqdm(
+ range(nb),
+ dynamic_ncols=True,
+ desc="world-data",
+ ):
frames, actions = generate_sequence(nb_steps=31)
all_frames += frames
end_time = time.perf_counter()
frames, actions = generate_sequence(nb_steps=31)
all_frames += frames
end_time = time.perf_counter()