X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=mine_mnist.py;h=7845d813ae20a86ed14f7fb5dd29970841c9397f;hp=06458b5a6e9b827c1373a9a685182dd134a40562;hb=4817708db50a18242ade3ba88971dd4ef0a73004;hpb=fa95c8d2d3663fc3f431beb52e792e5e469db449 diff --git a/mine_mnist.py b/mine_mnist.py index 06458b5..7845d81 100755 --- a/mine_mnist.py +++ b/mine_mnist.py @@ -8,6 +8,13 @@ from torch import nn ###################################################################### +if torch.cuda.is_available(): + device = torch.device('cuda') +else: + device = torch.device('cpu') + +###################################################################### + parser = argparse.ArgumentParser( description = 'An implementation of Mutual Information estimator with a deep model', formatter_class = argparse.ArgumentDefaultsHelpFormatter @@ -27,13 +34,6 @@ parser.add_argument('--mnist_classes', ###################################################################### -if torch.cuda.is_available(): - device = torch.device('cuda') -else: - device = torch.device('cpu') - -###################################################################### - def entropy(target): probas = [] for k in range(target.max() + 1):