X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=agtree2dot.git;a=blobdiff_plain;f=README.md;h=69806721794f3085c45668ec9d828f369addf76a;hp=452aa9cba717f36e8a1c9a11c0a09918bff12f2f;hb=HEAD;hpb=d9e6125c82f4e3775b8c868751b1657a6a147f55 diff --git a/README.md b/README.md index 452aa9c..6980672 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ This package provides a function that generates a [dot file](https://en.wikipedia.org/wiki/DOT_(graph_description_language)) -from a [pytorch](http://pytorch.org) autograd graph. +from a [PyTorch](http://pytorch.org) autograd graph. # Usage # @@ -20,12 +20,9 @@ resulting graph. A typical use is provided in [mlp.py](https://fleuret.org/git-extract/agtree2dot/mlp.py): ```python -import subprocess - from torch import nn from torch.nn import functional as fn from torch import Tensor -from torch.autograd import Variable from torch.nn import Module import agtree2dot @@ -43,8 +40,8 @@ class MLP(Module): return x mlp = MLP(10, 20, 1) -input = Variable(Tensor(100, 10).normal_()) -target = Variable(Tensor(100).normal_()) +input = Tensor(100, 10).normal_() +target = Tensor(100, 1).normal_() output = mlp(input) criterion = nn.MSELoss() loss = criterion(output, target) @@ -64,7 +61,8 @@ agtree2dot.save_dot(loss, print('Generated mlp.dot') try: - subprocess.check_call(["dot", "mlp.dot", "-Lg", "-T", "pdf", "-o", "mlp.pdf" ]) + subprocess.check_call(['dot', 'mlp.dot', '-Lg', '-T', 'pdf', '-o', 'mlp.pdf' ]) + except subprocess.CalledProcessError: print('Calling the dot command failed. Is Graphviz installed?') sys.exit(1)