Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 17 Jul 2023 12:25:45 +0000 (14:25 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 17 Jul 2023 12:25:45 +0000 (14:25 +0200)
world.py

index da7de75..64c7434 100755 (executable)
--- a/world.py
+++ b/world.py
@@ -169,7 +169,7 @@ def train_encoder(
             train_loss = F.cross_entropy(output, input)
 
             if lambda_entropy > 0:
-                loss = loss + lambda_entropy * loss_H(z, h_threshold=0.5)
+                train_loss = train_loss + lambda_entropy * loss_H(z, h_threshold=0.5)
 
             acc_train_loss += train_loss.item() * input.size(0)
 
@@ -439,26 +439,21 @@ if __name__ == "__main__":
         frame2seq,
         seq2frame,
     ) = create_data_and_processors(
-        # 10000, 1000,
-        100,
-        100,
-        nb_epochs=2,
+        25000, 1000,
+        nb_epochs=10,
         mode="first_last",
         nb_steps=20,
     )
 
-    input = test_input[:64]
+    input = test_input[:256]
 
     seq = frame2seq(input)
-
-    print(f"{seq.size()=} {seq.dtype=} {seq.min()=} {seq.max()=}")
-
     output = seq2frame(seq)
 
     torchvision.utils.save_image(
-        input.float() / (Box.nb_rgb_levels - 1), "orig.png", nrow=8
+        input.float() / (Box.nb_rgb_levels - 1), "orig.png", nrow=16
     )
 
     torchvision.utils.save_image(
-        output.float() / (Box.nb_rgb_levels - 1), "qtiz.png", nrow=8
+        output.float() / (Box.nb_rgb_levels - 1), "qtiz.png", nrow=16
     )