Update.
[agtree2dot.git] / agtree2dot.py
index 0f787ac..2dcc1d7 100755 (executable)
@@ -83,7 +83,7 @@ def fill_graph_lists(u, node_labels, node_list, link_list):
                     re.search('<class \'(.*\.|)([a-zA-Z0-9_]*)\'>', str(type(u))).group(2))
         node_list[u] = node
 
-        if isinstance(u, torch.autograd.Variable):
+        if hasattr(u, 'grad_fn'):
             fill_graph_lists(u.grad_fn, node_labels, node_list, link_list)
             add_link(node_list, link_list, u, 0, u.grad_fn, 0)
 
@@ -92,21 +92,15 @@ def fill_graph_lists(u, node_labels, node_list, link_list):
             add_link(node_list, link_list, u, 0, u.variable, 0)
 
         if hasattr(u, 'next_functions'):
-            i = 0
-            for v, j in u.next_functions:
+            for i, (v, j) in enumerate(u.next_functions):
                 fill_graph_lists(v, node_labels, node_list, link_list)
                 add_link(node_list, link_list, u, i, v, j)
-                i += 1
 
 ######################################################################
 
 def print_dot(node_list, link_list, out):
     out.write('digraph{\n')
 
-    out.write('  graph [fontname = "helvetica"];\n')
-    out.write('  node [fontname = "helvetica"];\n')
-    out.write('  edge [fontname = "helvetica"];\n')
-
     for n in node_list:
         node = node_list[n]