Update.
[agtree2dot.git] / agtree2dot.py
index 8931e36..2dcc1d7 100755 (executable)
@@ -75,7 +75,7 @@ def add_link(node_list, link_list, u, nu, v, nv):
 
 ######################################################################
 
-def build_ag_graph_lists(u, node_labels, node_list, link_list):
+def fill_graph_lists(u, node_labels, node_list, link_list):
 
     if u is not None and not u in node_list:
         node = Node(len(node_list) + 1,
@@ -83,16 +83,18 @@ def build_ag_graph_lists(u, node_labels, node_list, link_list):
                     re.search('<class \'(.*\.|)([a-zA-Z0-9_]*)\'>', str(type(u))).group(2))
         node_list[u] = node
 
-        if isinstance(u, torch.autograd.Variable):
-            build_ag_graph_lists(u.grad_fn, node_labels, node_list, link_list)
+        if hasattr(u, 'grad_fn'):
+            fill_graph_lists(u.grad_fn, node_labels, node_list, link_list)
             add_link(node_list, link_list, u, 0, u.grad_fn, 0)
-        else:
-            if hasattr(u, 'next_functions'):
-                i = 0
-                for v, j in u.next_functions:
-                    build_ag_graph_lists(v, node_labels, node_list, link_list)
-                    add_link(node_list, link_list, u, i, v, j)
-                    i += 1
+
+        if hasattr(u, 'variable'):
+            fill_graph_lists(u.variable, node_labels, node_list, link_list)
+            add_link(node_list, link_list, u, 0, u.variable, 0)
+
+        if hasattr(u, 'next_functions'):
+            for i, (v, j) in enumerate(u.next_functions):
+                fill_graph_lists(v, node_labels, node_list, link_list)
+                add_link(node_list, link_list, u, i, v, j)
 
 ######################################################################
 
@@ -102,14 +104,22 @@ def print_dot(node_list, link_list, out):
     for n in node_list:
         node = node_list[n]
 
-        out.write(
-            '  ' + \
-            str(node.id) + ' [shape=record,label="{ ' + \
-            slot_string(node.max_out, for_input = True) + \
-            node.label + \
-            slot_string(node.max_in, for_input = False) + \
-            ' }"]\n'
-        )
+        if isinstance(n, torch.autograd.Variable):
+            out.write(
+                '  ' + \
+                str(node.id) + ' [shape=note,style=filled, fillcolor="#e0e0ff",label="' + \
+                node.label + ' ' + re.search('torch\.Size\((.*)\)', str(n.data.size())).group(1) + \
+                '"]\n'
+            )
+        else:
+            out.write(
+                '  ' + \
+                str(node.id) + ' [shape=record,style=filled, fillcolor="#f0f0f0",label="{ ' + \
+                slot_string(node.max_out, for_input = True) + \
+                node.label + \
+                slot_string(node.max_in, for_input = False) + \
+                ' }"]\n'
+            )
 
     for n in link_list:
         out.write('  ' + \
@@ -124,7 +134,7 @@ def print_dot(node_list, link_list, out):
 
 def save_dot(x, node_labels = {}, out = sys.stdout):
     node_list, link_list = {}, []
-    build_ag_graph_lists(x, node_labels, node_list, link_list)
+    fill_graph_lists(x, node_labels, node_list, link_list)
     print_dot(node_list, link_list, out)
 
 ######################################################################