X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=agtree2dot.git;a=blobdiff_plain;f=agtree2dot.py;h=7643986992968d99998c7524e4266d10b3faaed1;hp=f215f94e26a141c377453eddf5848519d39bbb23;hb=977a75b33305e91933ee4df441d0a249e46f2a31;hpb=63f04303f0320d25d36e6a4f9f535e62cdb139e1 diff --git a/agtree2dot.py b/agtree2dot.py index f215f94..7643986 100755 --- a/agtree2dot.py +++ b/agtree2dot.py @@ -1,4 +1,3 @@ - ######################################################################### # This program is free software: you can redistribute it and/or modify # # it under the terms of the version 3 of the GNU General Public License # @@ -17,92 +16,121 @@ ######################################################################### import torch -import re -import sys +import sys, re + +###################################################################### + +class Link: + def __init__(self, from_node, from_nb, to_node, to_nb): + self.from_node = from_node + self.from_nb = from_nb + self.to_node = to_node + self.to_nb = to_nb + +class Node: + def __init__(self, id, label): + self.id = id + self.label = label + self.max_in = -1 + self.max_out = -1 + +def slot(node_list, n, k, for_input): + if for_input: + if node_list[n].max_out > 0: + return str(node_list[n].id) + ':input' + str(k) + else: + return str(node_list[n].id) + else: + if node_list[n].max_in > 0: + return str(node_list[n].id) + ':output' + str(k) + else: + return str(node_list[n].id) + +def slot_string(k, for_input): + result = '' + + if for_input: + label = 'input' + else: + label = 'output' + + if k > 0: + if not for_input: result = ' |' + result + result += ' { <' + label + '0> 0' + for j in range(1, k+1): + result += " | " + '<' + label + str(j) + '> ' + str(j) + result += " } " + if for_input: result = result + '| ' + + return result -import torch.autograd +###################################################################### + +def add_link(node_list, link_list, u, nu, v, nv): + link = Link(u, nu, v, nv) + link_list.append(link) + node_list[u].max_in = max(node_list[u].max_in, nu) + node_list[v].max_out = max(node_list[u].max_out, nv) ###################################################################### -def save_dot_rec(x, node_labels = {}, out = sys.stdout, drawn_node_id = {}): +def build_ag_graph_lists(u, node_labels, out, node_list, link_list): + + if not u in node_list: + node = Node(len(node_list) + 1, + (u in node_labels and node_labels[u]) or \ + re.search('', str(type(u))).group(2)) + node_list[u] = node + + if isinstance(u, torch.autograd.Variable): + build_ag_graph_lists(u.grad_fn, node_labels, out, node_list, link_list) + add_link(node_list, link_list, u, 0, u.grad_fn, 0) + else: + if hasattr(u, 'next_functions'): + i = 0 + for v, j in u.next_functions: + build_ag_graph_lists(v, node_labels, out, node_list, link_list) + add_link(node_list, link_list, u, i, v, j) + i += 1 - if isinstance(x, set): +###################################################################### - for y in x: - save_dot_rec(y, node_labels, out, drawn_node_id) +def print_dot(node_list, link_list, out): + out.write('digraph{\n') - else: + for n in node_list: + node = node_list[n] + + out.write( + ' ' + \ + str(node.id) + ' [shape=record,label="{ ' + \ + slot_string(node.max_out, for_input = True) + \ + node.label + \ + slot_string(node.max_in, for_input = False) + \ + ' }"]\n' + ) - if not x in drawn_node_id: - drawn_node_id[x] = len(drawn_node_id) + 1 - - # Draw the node (Variable or Function) if not already - # drawn - - if isinstance(x, torch.autograd.Variable): - name = ((x in node_labels and node_labels[x]) or 'Variable') - # Add the tensor size - name = name + ' [' - for d in range(0, x.data.dim()): - if d > 0: name = name + ', ' - name = name + str(x.data.size(d)) - name = name + ']' - - out.write(' ' + str(drawn_node_id[x]) + - ' [shape=record,penwidth=1,style=rounded,label="' + name + '"]\n') - - if hasattr(x, 'creator') and x.creator: - y = x.creator - save_dot_rec(y, node_labels, out, drawn_node_id) - # Edge to the creator - out.write(' ' + str(drawn_node_id[y]) + ' -> ' + str(drawn_node_id[x]) + '\n') - - elif isinstance(x, torch.autograd.Function): - name = ((x in node_labels and (node_labels[x] + ': ')) or '') + \ - re.search('<.*\.([a-zA-Z0-9_]*)\'>', str(type(x))).group(1) - - prefix = '' - suffix = '' - - if hasattr(x, 'num_inputs') and x.num_inputs > 1: - prefix = '{ ' - for i in range(0, x.num_inputs): - if i > 0: prefix = prefix + ' | ' - prefix = prefix + ' ' + str(i) - prefix = prefix + ' } | ' - - if hasattr(x, 'num_outputs') and x.num_outputs > 1: - suffix = ' | { ' - for i in range(0, x.num_outputs): - if i > 0: suffix = suffix + ' | ' - suffix = suffix + ' ' + str(i) - suffix = suffix + ' }' - - out.write(' ' + str(drawn_node_id[x]) + \ - ' [shape=record,label="{ ' + prefix + name + suffix + ' }"]\n') - - else: - - print('Cannot handle ' + str(type(x)) + ' (only Variables and Functions).') - exit(1) - - if hasattr(x, 'num_inputs'): - for i in range(0, x.num_inputs): - y = x.previous_functions[i][0] - save_dot_rec(y, node_labels, out, drawn_node_id) - from_str = str(drawn_node_id[y]) - if hasattr(y, 'num_outputs') and y.num_outputs > 1: - from_str = from_str + ':output' + str(x.previous_functions[i][1]) - to_str = str(drawn_node_id[x]) - if x.num_inputs > 1: - to_str = to_str + ':input' + str(i) - out.write(' ' + from_str + ' -> ' + to_str + '\n') + for n in link_list: + out.write(' ' + \ + slot(node_list, n.from_node, n.from_nb, for_input = False) + \ + ' -> ' + \ + slot(node_list, n.to_node, n.to_nb, for_input = True) + \ + '\n') + + out.write('}\n') ###################################################################### def save_dot(x, node_labels = {}, out = sys.stdout): - out.write('digraph {\n') - save_dot_rec(x, node_labels, out, {}) - out.write('}\n') + node_list = {} + link_list = [] + build_ag_graph_lists(x, node_labels, out, node_list, link_list) + print_dot(node_list, link_list, out) ###################################################################### + +# x = Variable(torch.rand(5)) +# y = torch.topk(x, 3) +# l = torch.sqrt(torch.norm(y[0]) + torch.norm(5.0 * y[1].float())) + +# save_dot(l, { l: 'variable l' }, open('/tmp/test.dot', 'w'))