projects
/
agtree2dot.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[agtree2dot.git]
/
agtree2dot.py
diff --git
a/agtree2dot.py
b/agtree2dot.py
index
0f787ac
..
2dcc1d7
100755
(executable)
--- a/
agtree2dot.py
+++ b/
agtree2dot.py
@@
-83,7
+83,7
@@
def fill_graph_lists(u, node_labels, node_list, link_list):
re.search('<class \'(.*\.|)([a-zA-Z0-9_]*)\'>', str(type(u))).group(2))
node_list[u] = node
re.search('<class \'(.*\.|)([a-zA-Z0-9_]*)\'>', str(type(u))).group(2))
node_list[u] = node
- if
isinstance(u, torch.autograd.Variable
):
+ if
hasattr(u, 'grad_fn'
):
fill_graph_lists(u.grad_fn, node_labels, node_list, link_list)
add_link(node_list, link_list, u, 0, u.grad_fn, 0)
fill_graph_lists(u.grad_fn, node_labels, node_list, link_list)
add_link(node_list, link_list, u, 0, u.grad_fn, 0)
@@
-92,21
+92,15
@@
def fill_graph_lists(u, node_labels, node_list, link_list):
add_link(node_list, link_list, u, 0, u.variable, 0)
if hasattr(u, 'next_functions'):
add_link(node_list, link_list, u, 0, u.variable, 0)
if hasattr(u, 'next_functions'):
- i = 0
- for v, j in u.next_functions:
+ for i, (v, j) in enumerate(u.next_functions):
fill_graph_lists(v, node_labels, node_list, link_list)
add_link(node_list, link_list, u, i, v, j)
fill_graph_lists(v, node_labels, node_list, link_list)
add_link(node_list, link_list, u, i, v, j)
- i += 1
######################################################################
def print_dot(node_list, link_list, out):
out.write('digraph{\n')
######################################################################
def print_dot(node_list, link_list, out):
out.write('digraph{\n')
- out.write(' graph [fontname = "helvetica"];\n')
- out.write(' node [fontname = "helvetica"];\n')
- out.write(' edge [fontname = "helvetica"];\n')
-
for n in node_list:
node = node_list[n]
for n in node_list:
node = node_list[n]