projects
/
pytorch.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
OCD update.
[pytorch.git]
/
mine_mnist.py
diff --git
a/mine_mnist.py
b/mine_mnist.py
index
06458b5
..
7845d81
100755
(executable)
--- a/
mine_mnist.py
+++ b/
mine_mnist.py
@@
-8,6
+8,13
@@
from torch import nn
######################################################################
######################################################################
+if torch.cuda.is_available():
+ device = torch.device('cuda')
+else:
+ device = torch.device('cpu')
+
+######################################################################
+
parser = argparse.ArgumentParser(
description = 'An implementation of Mutual Information estimator with a deep model',
formatter_class = argparse.ArgumentDefaultsHelpFormatter
parser = argparse.ArgumentParser(
description = 'An implementation of Mutual Information estimator with a deep model',
formatter_class = argparse.ArgumentDefaultsHelpFormatter
@@
-27,13
+34,6
@@
parser.add_argument('--mnist_classes',
######################################################################
######################################################################
-if torch.cuda.is_available():
- device = torch.device('cuda')
-else:
- device = torch.device('cpu')
-
-######################################################################
-
def entropy(target):
probas = []
for k in range(target.max() + 1):
def entropy(target):
probas = []
for k in range(target.max() + 1):