re.search('<class \'(.*\.|)([a-zA-Z0-9_]*)\'>', str(type(u))).group(2))
node_list[u] = node
- if isinstance(u, torch.autograd.Variable):
+ if hasattr(u, 'grad_fn'):
fill_graph_lists(u.grad_fn, node_labels, node_list, link_list)
add_link(node_list, link_list, u, 0, u.grad_fn, 0)
add_link(node_list, link_list, u, 0, u.variable, 0)
if hasattr(u, 'next_functions'):
- i = 0
- for v, j in u.next_functions:
+ for i, (v, j) in enumerate(u.next_functions):
fill_graph_lists(v, node_labels, node_list, link_list)
add_link(node_list, link_list, u, i, v, j)
- i += 1
######################################################################
if isinstance(n, torch.autograd.Variable):
out.write(
' ' + \
- str(node.id) + ' [shape=note,label="' + \
+ str(node.id) + ' [shape=note,style=filled, fillcolor="#e0e0ff",label="' + \
node.label + ' ' + re.search('torch\.Size\((.*)\)', str(n.data.size())).group(1) + \
'"]\n'
)
else:
out.write(
' ' + \
- str(node.id) + ' [shape=record,label="{ ' + \
+ str(node.id) + ' [shape=record,style=filled, fillcolor="#f0f0f0",label="{ ' + \
slot_string(node.max_out, for_input = True) + \
node.label + \
slot_string(node.max_in, for_input = False) + \