Now also catch ValueError.
[pytorch.git] / minidiffusion.py
index c88765c..066cbbb 100755 (executable)
@@ -66,12 +66,14 @@ def sample_mnist(nb):
     return result
 
 samplers = {
-    'gaussian_mixture': sample_gaussian_mixture,
-    'ramp': sample_ramp,
-    'two_discs': sample_two_discs,
-    'disc_grid': sample_disc_grid,
-    'spiral': sample_spiral,
-    'mnist': sample_mnist,
+    f.__name__.removeprefix('sample_') : f for f in [
+        sample_gaussian_mixture,
+        sample_ramp,
+        sample_two_discs,
+        sample_disc_grid,
+        sample_spiral,
+        sample_mnist,
+    ]
 }
 
 ######################################################################
@@ -313,7 +315,7 @@ if train_input.dim() == 2 and train_input.size(1) == 1:
 
     ax.legend(frameon = False, loc = 2)
 
-    filename = f'diffusion_{args.data}.pdf'
+    filename = f'minidiffusion_{args.data}.pdf'
     print(f'saving {filename}')
     fig.savefig(filename, bbox_inches='tight')
 
@@ -350,7 +352,7 @@ elif train_input.dim() == 2 and train_input.size(1) == 2:
 
     ax.legend(frameon = False, loc = 2)
 
-    filename = f'diffusion_{args.data}.pdf'
+    filename = f'minidiffusion_{args.data}.pdf'
     print(f'saving {filename}')
     fig.savefig(filename, bbox_inches='tight')
 
@@ -375,7 +377,7 @@ elif train_input.dim() == 4:
 
     result = 1 - torch.cat((t, x), 2) / 255
 
-    filename = f'diffusion_{args.data}.png'
+    filename = f'minidiffusion_{args.data}.png'
     print(f'saving {filename}')
     torchvision.utils.save_image(result, filename)