X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=agtree2dot.git;a=blobdiff_plain;f=agtree2dot.py;h=2dcc1d705752bf185248aa228b045a10a88c2c4f;hp=7643986992968d99998c7524e4266d10b3faaed1;hb=HEAD;hpb=977a75b33305e91933ee4df441d0a249e46f2a31 diff --git a/agtree2dot.py b/agtree2dot.py index 7643986..2dcc1d7 100755 --- a/agtree2dot.py +++ b/agtree2dot.py @@ -57,7 +57,7 @@ def slot_string(k, for_input): if k > 0: if not for_input: result = ' |' + result result += ' { <' + label + '0> 0' - for j in range(1, k+1): + for j in range(1, k + 1): result += " | " + '<' + label + str(j) + '> ' + str(j) result += " } " if for_input: result = result + '| ' @@ -67,31 +67,34 @@ def slot_string(k, for_input): ###################################################################### def add_link(node_list, link_list, u, nu, v, nv): - link = Link(u, nu, v, nv) - link_list.append(link) - node_list[u].max_in = max(node_list[u].max_in, nu) - node_list[v].max_out = max(node_list[u].max_out, nv) + if u is not None and v is not None: + link = Link(u, nu, v, nv) + link_list.append(link) + node_list[u].max_in = max(node_list[u].max_in, nu) + node_list[v].max_out = max(node_list[v].max_out, nv) ###################################################################### -def build_ag_graph_lists(u, node_labels, out, node_list, link_list): +def fill_graph_lists(u, node_labels, node_list, link_list): - if not u in node_list: + if u is not None and not u in node_list: node = Node(len(node_list) + 1, (u in node_labels and node_labels[u]) or \ re.search('', str(type(u))).group(2)) node_list[u] = node - if isinstance(u, torch.autograd.Variable): - build_ag_graph_lists(u.grad_fn, node_labels, out, node_list, link_list) + if hasattr(u, 'grad_fn'): + fill_graph_lists(u.grad_fn, node_labels, node_list, link_list) add_link(node_list, link_list, u, 0, u.grad_fn, 0) - else: - if hasattr(u, 'next_functions'): - i = 0 - for v, j in u.next_functions: - build_ag_graph_lists(v, node_labels, out, node_list, link_list) - add_link(node_list, link_list, u, i, v, j) - i += 1 + + if hasattr(u, 'variable'): + fill_graph_lists(u.variable, node_labels, node_list, link_list) + add_link(node_list, link_list, u, 0, u.variable, 0) + + if hasattr(u, 'next_functions'): + for i, (v, j) in enumerate(u.next_functions): + fill_graph_lists(v, node_labels, node_list, link_list) + add_link(node_list, link_list, u, i, v, j) ###################################################################### @@ -101,14 +104,22 @@ def print_dot(node_list, link_list, out): for n in node_list: node = node_list[n] - out.write( - ' ' + \ - str(node.id) + ' [shape=record,label="{ ' + \ - slot_string(node.max_out, for_input = True) + \ - node.label + \ - slot_string(node.max_in, for_input = False) + \ - ' }"]\n' - ) + if isinstance(n, torch.autograd.Variable): + out.write( + ' ' + \ + str(node.id) + ' [shape=note,style=filled, fillcolor="#e0e0ff",label="' + \ + node.label + ' ' + re.search('torch\.Size\((.*)\)', str(n.data.size())).group(1) + \ + '"]\n' + ) + else: + out.write( + ' ' + \ + str(node.id) + ' [shape=record,style=filled, fillcolor="#f0f0f0",label="{ ' + \ + slot_string(node.max_out, for_input = True) + \ + node.label + \ + slot_string(node.max_in, for_input = False) + \ + ' }"]\n' + ) for n in link_list: out.write(' ' + \ @@ -122,15 +133,8 @@ def print_dot(node_list, link_list, out): ###################################################################### def save_dot(x, node_labels = {}, out = sys.stdout): - node_list = {} - link_list = [] - build_ag_graph_lists(x, node_labels, out, node_list, link_list) + node_list, link_list = {}, [] + fill_graph_lists(x, node_labels, node_list, link_list) print_dot(node_list, link_list, out) ###################################################################### - -# x = Variable(torch.rand(5)) -# y = torch.topk(x, 3) -# l = torch.sqrt(torch.norm(y[0]) + torch.norm(5.0 * y[1].float())) - -# save_dot(l, { l: 'variable l' }, open('/tmp/test.dot', 'w'))