X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=minidiffusion.py;h=066cbbbe1fa365458ea163bea0e64e1cd2787c74;hb=87d376ff7929347865b199d05d003ab3b168f249;hp=e7be8c1c8651a3cc4ee4a578586e2e4dd3c29bcf;hpb=8d71d2f43cec159c7ca368c1dc4fa76f061d13b7;p=pytorch.git diff --git a/minidiffusion.py b/minidiffusion.py index e7be8c1..066cbbb 100755 --- a/minidiffusion.py +++ b/minidiffusion.py @@ -66,12 +66,14 @@ def sample_mnist(nb): return result samplers = { - 'gaussian_mixture': sample_gaussian_mixture, - 'ramp': sample_ramp, - 'two_discs': sample_two_discs, - 'disc_grid': sample_disc_grid, - 'spiral': sample_spiral, - 'mnist': sample_mnist, + f.__name__.removeprefix('sample_') : f for f in [ + sample_gaussian_mixture, + sample_ramp, + sample_two_discs, + sample_disc_grid, + sample_spiral, + sample_mnist, + ] } ######################################################################