X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=agtree2dot.git;a=blobdiff_plain;f=agtree2dot.py;h=2dcc1d705752bf185248aa228b045a10a88c2c4f;hp=f215f94e26a141c377453eddf5848519d39bbb23;hb=HEAD;hpb=63f04303f0320d25d36e6a4f9f535e62cdb139e1 diff --git a/agtree2dot.py b/agtree2dot.py index f215f94..2dcc1d7 100755 --- a/agtree2dot.py +++ b/agtree2dot.py @@ -1,4 +1,3 @@ - ######################################################################### # This program is free software: you can redistribute it and/or modify # # it under the terms of the version 3 of the GNU General Public License # @@ -17,92 +16,125 @@ ######################################################################### import torch -import re -import sys +import sys, re + +###################################################################### + +class Link: + def __init__(self, from_node, from_nb, to_node, to_nb): + self.from_node = from_node + self.from_nb = from_nb + self.to_node = to_node + self.to_nb = to_nb + +class Node: + def __init__(self, id, label): + self.id = id + self.label = label + self.max_in = -1 + self.max_out = -1 + +def slot(node_list, n, k, for_input): + if for_input: + if node_list[n].max_out > 0: + return str(node_list[n].id) + ':input' + str(k) + else: + return str(node_list[n].id) + else: + if node_list[n].max_in > 0: + return str(node_list[n].id) + ':output' + str(k) + else: + return str(node_list[n].id) -import torch.autograd +def slot_string(k, for_input): + result = '' + + if for_input: + label = 'input' + else: + label = 'output' + + if k > 0: + if not for_input: result = ' |' + result + result += ' { <' + label + '0> 0' + for j in range(1, k + 1): + result += " | " + '<' + label + str(j) + '> ' + str(j) + result += " } " + if for_input: result = result + '| ' + + return result ###################################################################### -def save_dot_rec(x, node_labels = {}, out = sys.stdout, drawn_node_id = {}): +def add_link(node_list, link_list, u, nu, v, nv): + if u is not None and v is not None: + link = Link(u, nu, v, nv) + link_list.append(link) + node_list[u].max_in = max(node_list[u].max_in, nu) + node_list[v].max_out = max(node_list[v].max_out, nv) - if isinstance(x, set): +###################################################################### - for y in x: - save_dot_rec(y, node_labels, out, drawn_node_id) +def fill_graph_lists(u, node_labels, node_list, link_list): - else: + if u is not None and not u in node_list: + node = Node(len(node_list) + 1, + (u in node_labels and node_labels[u]) or \ + re.search('', str(type(u))).group(2)) + node_list[u] = node + + if hasattr(u, 'grad_fn'): + fill_graph_lists(u.grad_fn, node_labels, node_list, link_list) + add_link(node_list, link_list, u, 0, u.grad_fn, 0) + + if hasattr(u, 'variable'): + fill_graph_lists(u.variable, node_labels, node_list, link_list) + add_link(node_list, link_list, u, 0, u.variable, 0) - if not x in drawn_node_id: - drawn_node_id[x] = len(drawn_node_id) + 1 - - # Draw the node (Variable or Function) if not already - # drawn - - if isinstance(x, torch.autograd.Variable): - name = ((x in node_labels and node_labels[x]) or 'Variable') - # Add the tensor size - name = name + ' [' - for d in range(0, x.data.dim()): - if d > 0: name = name + ', ' - name = name + str(x.data.size(d)) - name = name + ']' - - out.write(' ' + str(drawn_node_id[x]) + - ' [shape=record,penwidth=1,style=rounded,label="' + name + '"]\n') - - if hasattr(x, 'creator') and x.creator: - y = x.creator - save_dot_rec(y, node_labels, out, drawn_node_id) - # Edge to the creator - out.write(' ' + str(drawn_node_id[y]) + ' -> ' + str(drawn_node_id[x]) + '\n') - - elif isinstance(x, torch.autograd.Function): - name = ((x in node_labels and (node_labels[x] + ': ')) or '') + \ - re.search('<.*\.([a-zA-Z0-9_]*)\'>', str(type(x))).group(1) - - prefix = '' - suffix = '' - - if hasattr(x, 'num_inputs') and x.num_inputs > 1: - prefix = '{ ' - for i in range(0, x.num_inputs): - if i > 0: prefix = prefix + ' | ' - prefix = prefix + ' ' + str(i) - prefix = prefix + ' } | ' - - if hasattr(x, 'num_outputs') and x.num_outputs > 1: - suffix = ' | { ' - for i in range(0, x.num_outputs): - if i > 0: suffix = suffix + ' | ' - suffix = suffix + ' ' + str(i) - suffix = suffix + ' }' - - out.write(' ' + str(drawn_node_id[x]) + \ - ' [shape=record,label="{ ' + prefix + name + suffix + ' }"]\n') - - else: - - print('Cannot handle ' + str(type(x)) + ' (only Variables and Functions).') - exit(1) - - if hasattr(x, 'num_inputs'): - for i in range(0, x.num_inputs): - y = x.previous_functions[i][0] - save_dot_rec(y, node_labels, out, drawn_node_id) - from_str = str(drawn_node_id[y]) - if hasattr(y, 'num_outputs') and y.num_outputs > 1: - from_str = from_str + ':output' + str(x.previous_functions[i][1]) - to_str = str(drawn_node_id[x]) - if x.num_inputs > 1: - to_str = to_str + ':input' + str(i) - out.write(' ' + from_str + ' -> ' + to_str + '\n') + if hasattr(u, 'next_functions'): + for i, (v, j) in enumerate(u.next_functions): + fill_graph_lists(v, node_labels, node_list, link_list) + add_link(node_list, link_list, u, i, v, j) ###################################################################### -def save_dot(x, node_labels = {}, out = sys.stdout): - out.write('digraph {\n') - save_dot_rec(x, node_labels, out, {}) +def print_dot(node_list, link_list, out): + out.write('digraph{\n') + + for n in node_list: + node = node_list[n] + + if isinstance(n, torch.autograd.Variable): + out.write( + ' ' + \ + str(node.id) + ' [shape=note,style=filled, fillcolor="#e0e0ff",label="' + \ + node.label + ' ' + re.search('torch\.Size\((.*)\)', str(n.data.size())).group(1) + \ + '"]\n' + ) + else: + out.write( + ' ' + \ + str(node.id) + ' [shape=record,style=filled, fillcolor="#f0f0f0",label="{ ' + \ + slot_string(node.max_out, for_input = True) + \ + node.label + \ + slot_string(node.max_in, for_input = False) + \ + ' }"]\n' + ) + + for n in link_list: + out.write(' ' + \ + slot(node_list, n.from_node, n.from_nb, for_input = False) + \ + ' -> ' + \ + slot(node_list, n.to_node, n.to_nb, for_input = True) + \ + '\n') + out.write('}\n') ###################################################################### + +def save_dot(x, node_labels = {}, out = sys.stdout): + node_list, link_list = {}, [] + fill_graph_lists(x, node_labels, node_list, link_list) + print_dot(node_list, link_list, out) + +######################################################################