Update.
[pytorch.git] / tiny_vae.py
index fa09831..4d11c7f 100755 (executable)
@@ -175,12 +175,12 @@ def save_images(model, prefix=""):
     def save_image(x, filename):
         x = x * train_std + train_mu
         x = x.clamp(min=0, max=255) / 255
-        torchvision.utils.save_image(1 - x, filename, nrow=16, pad_value=0.8)
+        torchvision.utils.save_image(1 - x, filename, nrow=12, pad_value=1.0)
         log_string(f"wrote {filename}")
 
     # Save a bunch of train images
 
-    x = train_input[:256]
+    x = train_input[:36]
     save_image(x, f"{prefix}train_input.png")
 
     # Save the same images after encoding / decoding
@@ -194,7 +194,7 @@ def save_images(model, prefix=""):
 
     # Save a bunch of test images
 
-    x = test_input[:256]
+    x = test_input[:36]
     save_image(x, f"{prefix}input.png")
 
     # Save the same images after encoding / decoding