From: Francois Fleuret Date: Mon, 3 Dec 2018 16:58:07 +0000 (-0500) Subject: OCD update. X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=commitdiff_plain;h=4817708db50a18242ade3ba88971dd4ef0a73004 OCD update. --- diff --git a/mine_mnist.py b/mine_mnist.py index 06458b5..7845d81 100755 --- a/mine_mnist.py +++ b/mine_mnist.py @@ -8,6 +8,13 @@ from torch import nn ###################################################################### +if torch.cuda.is_available(): + device = torch.device('cuda') +else: + device = torch.device('cpu') + +###################################################################### + parser = argparse.ArgumentParser( description = 'An implementation of Mutual Information estimator with a deep model', formatter_class = argparse.ArgumentDefaultsHelpFormatter @@ -27,13 +34,6 @@ parser.add_argument('--mnist_classes', ###################################################################### -if torch.cuda.is_available(): - device = torch.device('cuda') -else: - device = torch.device('cpu') - -###################################################################### - def entropy(target): probas = [] for k in range(target.max() + 1):