Update.
[agtree2dot.git] / agtree2dot.py
index 8cc9e8c..2dcc1d7 100755 (executable)
@@ -83,7 +83,7 @@ def fill_graph_lists(u, node_labels, node_list, link_list):
                     re.search('<class \'(.*\.|)([a-zA-Z0-9_]*)\'>', str(type(u))).group(2))
         node_list[u] = node
 
-        if isinstance(u, torch.autograd.Variable):
+        if hasattr(u, 'grad_fn'):
             fill_graph_lists(u.grad_fn, node_labels, node_list, link_list)
             add_link(node_list, link_list, u, 0, u.grad_fn, 0)
 
@@ -92,11 +92,9 @@ def fill_graph_lists(u, node_labels, node_list, link_list):
             add_link(node_list, link_list, u, 0, u.variable, 0)
 
         if hasattr(u, 'next_functions'):
-            i = 0
-            for v, j in u.next_functions:
+            for i, (v, j) in enumerate(u.next_functions):
                 fill_graph_lists(v, node_labels, node_list, link_list)
                 add_link(node_list, link_list, u, i, v, j)
-                i += 1
 
 ######################################################################
 
@@ -109,14 +107,14 @@ def print_dot(node_list, link_list, out):
         if isinstance(n, torch.autograd.Variable):
             out.write(
                 '  ' + \
-                str(node.id) + ' [shape=note,label="' + \
+                str(node.id) + ' [shape=note,style=filled, fillcolor="#e0e0ff",label="' + \
                 node.label + ' ' + re.search('torch\.Size\((.*)\)', str(n.data.size())).group(1) + \
                 '"]\n'
             )
         else:
             out.write(
                 '  ' + \
-                str(node.id) + ' [shape=record,label="{ ' + \
+                str(node.id) + ' [shape=record,style=filled, fillcolor="#f0f0f0",label="{ ' + \
                 slot_string(node.max_out, for_input = True) + \
                 node.label + \
                 slot_string(node.max_in, for_input = False) + \