re.search('<class \'(.*\.|)([a-zA-Z0-9_]*)\'>', str(type(u))).group(2))
node_list[u] = node
- if isinstance(u, torch.autograd.Variable):
+ if hasattr(u, 'grad_fn'):
fill_graph_lists(u.grad_fn, node_labels, node_list, link_list)
add_link(node_list, link_list, u, 0, u.grad_fn, 0)
def print_dot(node_list, link_list, out):
out.write('digraph{\n')
- out.write(' graph [fontname = "helvetica"];\n')
- out.write(' node [fontname = "helvetica"];\n')
- out.write(' edge [fontname = "helvetica"];\n')
-
for n in node_list:
node = node_list[n]