4eef05a4fca1cd967af795b817a3827fd0cdd7d8
[agtree2dot.git] / agtree2dot.py
1 #########################################################################
2 # This program is free software: you can redistribute it and/or modify  #
3 # it under the terms of the version 3 of the GNU General Public License #
4 # as published by the Free Software Foundation.                         #
5 #                                                                       #
6 # This program is distributed in the hope that it will be useful, but   #
7 # WITHOUT ANY WARRANTY; without even the implied warranty of            #
8 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      #
9 # General Public License for more details.                              #
10 #                                                                       #
11 # You should have received a copy of the GNU General Public License     #
12 # along with this program. If not, see <http://www.gnu.org/licenses/>.  #
13 #                                                                       #
14 # Written by and Copyright (C) Francois Fleuret                         #
15 # Contact <francois.fleuret@idiap.ch> for comments & bug reports        #
16 #########################################################################
17
18 import torch
19 import sys, re
20
21 ######################################################################
22
23 class Link:
24     def __init__(self, from_node, from_nb, to_node, to_nb):
25         self.from_node = from_node
26         self.from_nb = from_nb
27         self.to_node = to_node
28         self.to_nb = to_nb
29
30 class Node:
31     def __init__(self, id, label):
32         self.id = id
33         self.label = label
34         self.max_in = -1
35         self.max_out = -1
36
37 def slot(node_list, n, k, for_input):
38     if for_input:
39         if node_list[n].max_out > 0:
40             return str(node_list[n].id) + ':input' + str(k)
41         else:
42             return str(node_list[n].id)
43     else:
44         if node_list[n].max_in > 0:
45             return str(node_list[n].id) + ':output' + str(k)
46         else:
47             return str(node_list[n].id)
48
49 def slot_string(k, for_input):
50     result = ''
51
52     if for_input:
53         label = 'input'
54     else:
55         label = 'output'
56
57     if k > 0:
58         if not for_input: result = ' |' + result
59         result +=  ' { <' + label + '0> 0'
60         for j in range(1, k + 1):
61             result += " | " + '<' + label + str(j) + '> ' + str(j)
62         result += " } "
63         if for_input: result = result + '| '
64
65     return result
66
67 ######################################################################
68
69 def add_link(node_list, link_list, u, nu, v, nv):
70     if u is not None and v is not None:
71         link = Link(u, nu, v, nv)
72         link_list.append(link)
73         node_list[u].max_in  = max(node_list[u].max_in,  nu)
74         node_list[v].max_out = max(node_list[v].max_out, nv)
75
76 ######################################################################
77
78 def fill_graph_lists(u, node_labels, node_list, link_list):
79
80     if u is not None and not u in node_list:
81         node = Node(len(node_list) + 1,
82                     (u in node_labels and node_labels[u]) or \
83                     re.search('<class \'(.*\.|)([a-zA-Z0-9_]*)\'>', str(type(u))).group(2))
84         node_list[u] = node
85
86         if hasattr(u, 'grad_fn'):
87             fill_graph_lists(u.grad_fn, node_labels, node_list, link_list)
88             add_link(node_list, link_list, u, 0, u.grad_fn, 0)
89
90         if hasattr(u, 'variable'):
91             fill_graph_lists(u.variable, node_labels, node_list, link_list)
92             add_link(node_list, link_list, u, 0, u.variable, 0)
93
94         if hasattr(u, 'next_functions'):
95             i = 0
96             for v, j in u.next_functions:
97                 fill_graph_lists(v, node_labels, node_list, link_list)
98                 add_link(node_list, link_list, u, i, v, j)
99                 i += 1
100
101 ######################################################################
102
103 def print_dot(node_list, link_list, out):
104     out.write('digraph{\n')
105
106     out.write('  graph [fontname = "helvetica"];\n')
107     out.write('  node [fontname = "helvetica"];\n')
108     out.write('  edge [fontname = "helvetica"];\n')
109
110     for n in node_list:
111         node = node_list[n]
112
113         if isinstance(n, torch.autograd.Variable):
114             out.write(
115                 '  ' + \
116                 str(node.id) + ' [shape=note,style=filled, fillcolor="#e0e0ff",label="' + \
117                 node.label + ' ' + re.search('torch\.Size\((.*)\)', str(n.data.size())).group(1) + \
118                 '"]\n'
119             )
120         else:
121             out.write(
122                 '  ' + \
123                 str(node.id) + ' [shape=record,style=filled, fillcolor="#f0f0f0",label="{ ' + \
124                 slot_string(node.max_out, for_input = True) + \
125                 node.label + \
126                 slot_string(node.max_in, for_input = False) + \
127                 ' }"]\n'
128             )
129
130     for n in link_list:
131         out.write('  ' + \
132                   slot(node_list, n.from_node, n.from_nb, for_input = False) + \
133                   ' -> ' + \
134                   slot(node_list, n.to_node, n.to_nb, for_input = True) + \
135                   '\n')
136
137     out.write('}\n')
138
139 ######################################################################
140
141 def save_dot(x, node_labels = {}, out = sys.stdout):
142     node_list, link_list = {}, []
143     fill_graph_lists(x, node_labels, node_list, link_list)
144     print_dot(node_list, link_list, out)
145
146 ######################################################################