Update: add tqdm progress bars to the encoder-training and world-data loops, normalize spacing, and lower the default nb_epochs from 50 to 10.
author    François Fleuret <francois@fleuret.org>
Tue, 11 Jul 2023 06:13:35 +0000 (08:13 +0200)
committer François Fleuret <francois@fleuret.org>
Tue, 11 Jul 2023 06:13:35 +0000 (08:13 +0200)
diff --git a/world.py b/world.py

index a43eff9..d32d545 100755
--- a/world.py
+++ b/world.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 
-import math, sys
+import math, sys, tqdm
 
 import torch, torchvision
 
@@ -232,10 +232,11 @@ class Normalizer(nn.Module):
     def __init__(self, mu, std):
         super().__init__()
         self.mu = nn.Parameter(mu)
-        self.log_var = nn.Parameter(2*torch.log(std))
+        self.log_var = nn.Parameter(2 * torch.log(std))
 
     def forward(self, x):
-        return (x-self.mu)/torch.exp(self.log_var/2.0)
+        return (x - self.mu) / torch.exp(self.log_var / 2.0)
+
 
 class SignSTE(nn.Module):
     def __init__(self):
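
The Normalizer change above is whitespace-only; the module still stores the
scale as log_var = 2 * log(std), so exp(log_var / 2) recovers std and the
learned scale stays positive with no explicit constraint. A minimal sketch,
assuming world.py is importable and using illustrative mu/std values:

    import torch
    from world import Normalizer

    mu, std = torch.zeros(3), torch.full((3,), 2.0)
    norm = Normalizer(mu, std)
    x = torch.randn(5, 3)
    # (x - 0) / exp(2 * log(2) / 2) == x / 2
    assert torch.allclose(norm(x), x / 2.0)
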
@@ -256,8 +257,9 @@ def train_encoder(
     dim_hidden=64,
     block_size=16,
     nb_bits_per_block=10,
-    lr_start=1e-3, lr_end=1e-5,
-    nb_epochs=50,
+    lr_start=1e-3,
+    lr_end=1e-5,
+    nb_epochs=10,
     batch_size=25,
     device=torch.device("cpu"),
 ):
@@ -312,12 +314,19 @@ def train_encoder(
     model.to(device)
 
     for k in range(nb_epochs):
-        lr=math.exp(math.log(lr_start) + math.log(lr_end/lr_start)/(nb_epochs-1)*k)
+        lr = math.exp(
+            math.log(lr_start) + math.log(lr_end / lr_start) / (nb_epochs - 1) * k
+        )
         print(f"lr {lr}")
         optimizer = torch.optim.Adam(model.parameters(), lr=lr)
         acc_loss, nb_samples = 0.0, 0
 
-        for input in train_input.split(batch_size):
+        for input in tqdm.tqdm(
+            train_input.split(batch_size),
+            dynamic_ncols=True,
+            desc="vqae-train",
+            total=train_input.size(0) // batch_size,
+        ):
             output = model(input)
             loss = F.mse_loss(output, input)
             acc_loss += loss.item() * input.size(0)
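
The reformatted lr expression is a geometric interpolation between the two
endpoints: lr_k = lr_start * (lr_end / lr_start) ** (k / (nb_epochs - 1)),
i.e. lr_start at epoch 0 decaying to lr_end at the last epoch (1e-3 down to
1e-5 with the new defaults). A standalone check of the equivalence:

    import math

    lr_start, lr_end, nb_epochs = 1e-3, 1e-5, 10
    for k in range(nb_epochs):
        lr = math.exp(
            math.log(lr_start) + math.log(lr_end / lr_start) / (nb_epochs - 1) * k
        )
        closed_form = lr_start * (lr_end / lr_start) ** (k / (nb_epochs - 1))
        assert abs(lr - closed_form) < 1e-12

Note that total=train_input.size(0) // batch_size undercounts by one batch
when the size is not a multiple of batch_size, since Tensor.split keeps the
remainder as a final shorter chunk; the bar length is purely cosmetic either
way.
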
@@ -341,7 +350,11 @@ if __name__ == "__main__":
     all_frames = []
     nb = 25000
     start_time = time.perf_counter()
-    for n in range(nb):
+    for n in tqdm.tqdm(
+        range(nb),
+        dynamic_ncols=True,
+        desc="world-data",
+    ):
         frames, actions = generate_sequence(nb_steps=31)
         all_frames += frames
     end_time = time.perf_counter()
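
Wrapping range(nb) needs no explicit total, since tqdm reads it from
len(range(nb)); only the display options are passed. The same pattern in
isolation, mirroring the timing code of the hunk above (the throughput
print is illustrative, not from world.py):

    import time, tqdm

    nb = 25000
    start_time = time.perf_counter()
    for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world-data"):
        pass  # generate_sequence(nb_steps=31) in the real loop
    end_time = time.perf_counter()
    print(f"{nb / (end_time - start_time):.0f} iterations/s")
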