X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=agtree2dot.git;a=blobdiff_plain;f=agtree2dot.py;fp=agtree2dot.py;h=8cc9e8cb4cad429a3c8b4b338bf16b9619f43d01;hp=8931e366b5f2937ba25330fe4268c4143d387753;hb=d9e6125c82f4e3775b8c868751b1657a6a147f55;hpb=55332826c1d0ec125fc1d2db6644c98b1640d4a2 diff --git a/agtree2dot.py b/agtree2dot.py index 8931e36..8cc9e8c 100755 --- a/agtree2dot.py +++ b/agtree2dot.py @@ -75,7 +75,7 @@ def add_link(node_list, link_list, u, nu, v, nv): ###################################################################### -def build_ag_graph_lists(u, node_labels, node_list, link_list): +def fill_graph_lists(u, node_labels, node_list, link_list): if u is not None and not u in node_list: node = Node(len(node_list) + 1, @@ -84,15 +84,19 @@ def build_ag_graph_lists(u, node_labels, node_list, link_list): node_list[u] = node if isinstance(u, torch.autograd.Variable): - build_ag_graph_lists(u.grad_fn, node_labels, node_list, link_list) + fill_graph_lists(u.grad_fn, node_labels, node_list, link_list) add_link(node_list, link_list, u, 0, u.grad_fn, 0) - else: - if hasattr(u, 'next_functions'): - i = 0 - for v, j in u.next_functions: - build_ag_graph_lists(v, node_labels, node_list, link_list) - add_link(node_list, link_list, u, i, v, j) - i += 1 + + if hasattr(u, 'variable'): + fill_graph_lists(u.variable, node_labels, node_list, link_list) + add_link(node_list, link_list, u, 0, u.variable, 0) + + if hasattr(u, 'next_functions'): + i = 0 + for v, j in u.next_functions: + fill_graph_lists(v, node_labels, node_list, link_list) + add_link(node_list, link_list, u, i, v, j) + i += 1 ###################################################################### @@ -102,14 +106,22 @@ def print_dot(node_list, link_list, out): for n in node_list: node = node_list[n] - out.write( - ' ' + \ - str(node.id) + ' [shape=record,label="{ ' + \ - slot_string(node.max_out, for_input = True) + \ - node.label + \ - slot_string(node.max_in, for_input = False) + \ - ' }"]\n' - ) + if isinstance(n, torch.autograd.Variable): + out.write( + ' ' + \ + str(node.id) + ' [shape=note,label="' + \ + node.label + ' ' + re.search('torch\.Size\((.*)\)', str(n.data.size())).group(1) + \ + '"]\n' + ) + else: + out.write( + ' ' + \ + str(node.id) + ' [shape=record,label="{ ' + \ + slot_string(node.max_out, for_input = True) + \ + node.label + \ + slot_string(node.max_in, for_input = False) + \ + ' }"]\n' + ) for n in link_list: out.write(' ' + \ @@ -124,7 +136,7 @@ def print_dot(node_list, link_list, out): def save_dot(x, node_labels = {}, out = sys.stdout): node_list, link_list = {}, [] - build_ag_graph_lists(x, node_labels, node_list, link_list) + fill_graph_lists(x, node_labels, node_list, link_list) print_dot(node_list, link_list, out) ######################################################################