From: François Fleuret Date: Mon, 17 Jul 2023 12:25:45 +0000 (+0200) Subject: Update. X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=e3a8032a070175ece08fc79c77312d5f2f59150e;p=picoclvr.git Update. --- diff --git a/world.py b/world.py index da7de75..64c7434 100755 --- a/world.py +++ b/world.py @@ -169,7 +169,7 @@ def train_encoder( train_loss = F.cross_entropy(output, input) if lambda_entropy > 0: - loss = loss + lambda_entropy * loss_H(z, h_threshold=0.5) + train_loss = train_loss + lambda_entropy * loss_H(z, h_threshold=0.5) acc_train_loss += train_loss.item() * input.size(0) @@ -439,26 +439,21 @@ if __name__ == "__main__": frame2seq, seq2frame, ) = create_data_and_processors( - # 10000, 1000, - 100, - 100, - nb_epochs=2, + 25000, 1000, + nb_epochs=10, mode="first_last", nb_steps=20, ) - input = test_input[:64] + input = test_input[:256] seq = frame2seq(input) - - print(f"{seq.size()=} {seq.dtype=} {seq.min()=} {seq.max()=}") - output = seq2frame(seq) torchvision.utils.save_image( - input.float() / (Box.nb_rgb_levels - 1), "orig.png", nrow=8 + input.float() / (Box.nb_rgb_levels - 1), "orig.png", nrow=16 ) torchvision.utils.save_image( - output.float() / (Box.nb_rgb_levels - 1), "qtiz.png", nrow=8 + output.float() / (Box.nb_rgb_levels - 1), "qtiz.png", nrow=16 )