Update.
[pytorch.git] / minidiffusion.py
index e7be8c1..066cbbb 100755 (executable)
@@ -66,12 +66,14 @@ def sample_mnist(nb):
     return result
 
 samplers = {
-    'gaussian_mixture': sample_gaussian_mixture,
-    'ramp': sample_ramp,
-    'two_discs': sample_two_discs,
-    'disc_grid': sample_disc_grid,
-    'spiral': sample_spiral,
-    'mnist': sample_mnist,
+    f.__name__.removeprefix('sample_') : f for f in [
+        sample_gaussian_mixture,
+        sample_ramp,
+        sample_two_discs,
+        sample_disc_grid,
+        sample_spiral,
+        sample_mnist,
+    ]
 }
 
 ######################################################################