X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=tiny_vae.py;fp=tiny_vae.py;h=4d11c7f41c80d42903893cb2b756bb9392c8e79c;hp=fa09831c22d0fc8393ed15a6f3420a1acaef90eb;hb=bbe5b7ddb723696fb5388be950af252cb95eb5fb;hpb=0fdaaceb231d31d53d0c623848b8ac56964bedb5 diff --git a/tiny_vae.py b/tiny_vae.py index fa09831..4d11c7f 100755 --- a/tiny_vae.py +++ b/tiny_vae.py @@ -175,12 +175,12 @@ def save_images(model, prefix=""): def save_image(x, filename): x = x * train_std + train_mu x = x.clamp(min=0, max=255) / 255 - torchvision.utils.save_image(1 - x, filename, nrow=16, pad_value=0.8) + torchvision.utils.save_image(1 - x, filename, nrow=12, pad_value=1.0) log_string(f"wrote {filename}") # Save a bunch of train images - x = train_input[:256] + x = train_input[:36] save_image(x, f"{prefix}train_input.png") # Save the same images after encoding / decoding @@ -194,7 +194,7 @@ def save_images(model, prefix=""): # Save a bunch of test images - x = test_input[:256] + x = test_input[:36] save_image(x, f"{prefix}input.png") # Save the same images after encoding / decoding