Style variable nodes differently, shows the tensor size, invoke the dot command in...
authorFrancois Fleuret <francois@fleuret.org>
Mon, 21 Aug 2017 06:19:09 +0000 (08:19 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Mon, 21 Aug 2017 06:19:09 +0000 (08:19 +0200)
README.md
agtree2dot.py
mlp.pdf
mlp.py

index 6b996b4..452aa9c 100644 (file)
--- a/README.md
+++ b/README.md
@@ -10,14 +10,17 @@ from a [pytorch](http://pytorch.org) autograd graph.
 
 ### agtree2dot.save_dot(variable, variable_labels, result_file) ###
 
-Saves into `result_file` a dot file corresponding to the autograd graph for `variable`, which can be either a single `Variable` or a set of `Variable`s. The dictionary `variable_labels` associates strings to some variables, which will be used in the resulting graph.
+Saves into `result_file` a dot file corresponding to the autograd
+graph for the `Variable` `variable`. The dictionary `variable_labels`
+associates strings to some variables, which will be used in the
+resulting graph.
 
 ## Example ##
 
-A typical use would be:
+A typical use is provided in [mlp.py](https://fleuret.org/git-extract/agtree2dot/mlp.py):
 
 ```python
-import torch
+import subprocess
 
 from torch import nn
 from torch.nn import functional as fn
@@ -47,15 +50,28 @@ criterion = nn.MSELoss()
 loss = criterion(output, target)
 
 agtree2dot.save_dot(loss,
-                    { input: 'input', target: 'target', loss: 'loss' },
+                    {
+                        input: 'input',
+                        target: 'target',
+                        loss: 'loss',
+                        mlp.fc1.weight: 'weight1',
+                        mlp.fc1.bias: 'bias1',
+                        mlp.fc2.weight: 'weight2',
+                        mlp.fc2.bias: 'bias2',
+                    },
                     open('./mlp.dot', 'w'))
-```
 
-which would generate a file mlp.dot, which can then be translated to
-pdf using the [Graphviz tools](http://www.graphviz.org/)
+print('Generated mlp.dot')
 
-```
-dot mlp.dot -Lg -T pdf -o mlp.pdf
+try:
+    subprocess.check_call(["dot", "mlp.dot", "-Lg", "-T", "pdf", "-o", "mlp.pdf" ])
+except subprocess.CalledProcessError:
+    print('Calling the dot command failed. Is Graphviz installed?')
+    sys.exit(1)
+
+print('Generated mlp.pdf')
 ```
 
-to produce [mlp.pdf.](https://fleuret.org/git-extract/agtree2dot/mlp.pdf)
+which would generate a file mlp.dot and try to generate
+[mlp.pdf](https://fleuret.org/git-extract/agtree2dot/mlp.pdf) from it
+with [Graphviz tools.](http://www.graphviz.org/)
index 8931e36..8cc9e8c 100755 (executable)
@@ -75,7 +75,7 @@ def add_link(node_list, link_list, u, nu, v, nv):
 
 ######################################################################
 
-def build_ag_graph_lists(u, node_labels, node_list, link_list):
+def fill_graph_lists(u, node_labels, node_list, link_list):
 
     if u is not None and not u in node_list:
         node = Node(len(node_list) + 1,
@@ -84,15 +84,19 @@ def build_ag_graph_lists(u, node_labels, node_list, link_list):
         node_list[u] = node
 
         if isinstance(u, torch.autograd.Variable):
-            build_ag_graph_lists(u.grad_fn, node_labels, node_list, link_list)
+            fill_graph_lists(u.grad_fn, node_labels, node_list, link_list)
             add_link(node_list, link_list, u, 0, u.grad_fn, 0)
-        else:
-            if hasattr(u, 'next_functions'):
-                i = 0
-                for v, j in u.next_functions:
-                    build_ag_graph_lists(v, node_labels, node_list, link_list)
-                    add_link(node_list, link_list, u, i, v, j)
-                    i += 1
+
+        if hasattr(u, 'variable'):
+            fill_graph_lists(u.variable, node_labels, node_list, link_list)
+            add_link(node_list, link_list, u, 0, u.variable, 0)
+
+        if hasattr(u, 'next_functions'):
+            i = 0
+            for v, j in u.next_functions:
+                fill_graph_lists(v, node_labels, node_list, link_list)
+                add_link(node_list, link_list, u, i, v, j)
+                i += 1
 
 ######################################################################
 
@@ -102,14 +106,22 @@ def print_dot(node_list, link_list, out):
     for n in node_list:
         node = node_list[n]
 
-        out.write(
-            '  ' + \
-            str(node.id) + ' [shape=record,label="{ ' + \
-            slot_string(node.max_out, for_input = True) + \
-            node.label + \
-            slot_string(node.max_in, for_input = False) + \
-            ' }"]\n'
-        )
+        if isinstance(n, torch.autograd.Variable):
+            out.write(
+                '  ' + \
+                str(node.id) + ' [shape=note,label="' + \
+                node.label + ' ' + re.search('torch\.Size\((.*)\)', str(n.data.size())).group(1) + \
+                '"]\n'
+            )
+        else:
+            out.write(
+                '  ' + \
+                str(node.id) + ' [shape=record,label="{ ' + \
+                slot_string(node.max_out, for_input = True) + \
+                node.label + \
+                slot_string(node.max_in, for_input = False) + \
+                ' }"]\n'
+            )
 
     for n in link_list:
         out.write('  ' + \
@@ -124,7 +136,7 @@ def print_dot(node_list, link_list, out):
 
 def save_dot(x, node_labels = {}, out = sys.stdout):
     node_list, link_list = {}, []
-    build_ag_graph_lists(x, node_labels, node_list, link_list)
+    fill_graph_lists(x, node_labels, node_list, link_list)
     print_dot(node_list, link_list, out)
 
 ######################################################################
diff --git a/mlp.pdf b/mlp.pdf
index 0f41f81..52dea89 100644 (file)
Binary files a/mlp.pdf and b/mlp.pdf differ
diff --git a/mlp.py b/mlp.py
index 8497848..3c5f026 100755 (executable)
--- a/mlp.py
+++ b/mlp.py
@@ -17,6 +17,8 @@
 # Contact <francois.fleuret@idiap.ch> for comments & bug reports        #
 #########################################################################
 
+import subprocess
+
 from torch import nn
 from torch.nn import functional as fn
 from torch import Tensor
@@ -45,8 +47,23 @@ criterion = nn.MSELoss()
 loss = criterion(output, target)
 
 agtree2dot.save_dot(loss,
-                    { input: 'input', target: 'target', loss: 'loss' },
+                    {
+                        input: 'input',
+                        target: 'target',
+                        loss: 'loss',
+                        mlp.fc1.weight: 'weight1',
+                        mlp.fc1.bias: 'bias1',
+                        mlp.fc2.weight: 'weight2',
+                        mlp.fc2.bias: 'bias2',
+                    },
                     open('./mlp.dot', 'w'))
 
-print('Generated mlp.dot. You can convert it to pdf with')
-print('> dot mlp.dot -Lg -T pdf -o mlp.pdf')
+print('Generated mlp.dot')
+
+try:
+    subprocess.check_call(["dot", "mlp.dot", "-Lg", "-T", "pdf", "-o", "mlp.pdf" ])
+except subprocess.CalledProcessError:
+    print('Calling the dot command failed. Is Graphviz installed?')
+    sys.exit(1)
+
+print('Generated mlp.pdf')