X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=agtree2dot.git;a=blobdiff_plain;f=agtree2dot.py;h=2dcc1d705752bf185248aa228b045a10a88c2c4f;hp=0f787accad9edc1d921d839bde6d81a233d6bfbf;hb=HEAD;hpb=21558202134d649b29fb700481b15963e71b1d1f diff --git a/agtree2dot.py b/agtree2dot.py index 0f787ac..2dcc1d7 100755 --- a/agtree2dot.py +++ b/agtree2dot.py @@ -83,7 +83,7 @@ def fill_graph_lists(u, node_labels, node_list, link_list): re.search('', str(type(u))).group(2)) node_list[u] = node - if isinstance(u, torch.autograd.Variable): + if hasattr(u, 'grad_fn'): fill_graph_lists(u.grad_fn, node_labels, node_list, link_list) add_link(node_list, link_list, u, 0, u.grad_fn, 0) @@ -92,21 +92,15 @@ def fill_graph_lists(u, node_labels, node_list, link_list): add_link(node_list, link_list, u, 0, u.variable, 0) if hasattr(u, 'next_functions'): - i = 0 - for v, j in u.next_functions: + for i, (v, j) in enumerate(u.next_functions): fill_graph_lists(v, node_labels, node_list, link_list) add_link(node_list, link_list, u, i, v, j) - i += 1 ###################################################################### def print_dot(node_list, link_list, out): out.write('digraph{\n') - out.write(' graph [fontname = "helvetica"];\n') - out.write(' node [fontname = "helvetica"];\n') - out.write(' edge [fontname = "helvetica"];\n') - for n in node_list: node = node_list[n]