Update.
[agtree2dot.git] / agtree2dot.py
1 #########################################################################
2 # This program is free software: you can redistribute it and/or modify  #
3 # it under the terms of the version 3 of the GNU General Public License #
4 # as published by the Free Software Foundation.                         #
5 #                                                                       #
6 # This program is distributed in the hope that it will be useful, but   #
7 # WITHOUT ANY WARRANTY; without even the implied warranty of            #
8 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      #
9 # General Public License for more details.                              #
10 #                                                                       #
11 # You should have received a copy of the GNU General Public License     #
12 # along with this program. If not, see <http://www.gnu.org/licenses/>.  #
13 #                                                                       #
14 # Written by and Copyright (C) Francois Fleuret                         #
15 # Contact <francois.fleuret@idiap.ch> for comments & bug reports        #
16 #########################################################################
17
18 import torch
19 import sys, re
20
21 ######################################################################
22
23 class Link:
24     def __init__(self, from_node, from_nb, to_node, to_nb):
25         self.from_node = from_node
26         self.from_nb = from_nb
27         self.to_node = to_node
28         self.to_nb = to_nb
29
30 class Node:
31     def __init__(self, id, label):
32         self.id = id
33         self.label = label
34         self.max_in = -1
35         self.max_out = -1
36
37 def slot(node_list, n, k, for_input):
38     if for_input:
39         if node_list[n].max_out > 0:
40             return str(node_list[n].id) + ':input' + str(k)
41         else:
42             return str(node_list[n].id)
43     else:
44         if node_list[n].max_in > 0:
45             return str(node_list[n].id) + ':output' + str(k)
46         else:
47             return str(node_list[n].id)
48
49 def slot_string(k, for_input):
50     result = ''
51
52     if for_input:
53         label = 'input'
54     else:
55         label = 'output'
56
57     if k > 0:
58         if not for_input: result = ' |' + result
59         result +=  ' { <' + label + '0> 0'
60         for j in range(1, k+1):
61             result += " | " + '<' + label + str(j) + '> ' + str(j)
62         result += " } "
63         if for_input: result = result + '| '
64
65     return result
66
67 ######################################################################
68
69 def add_link(node_list, link_list, u, nu, v, nv):
70     link = Link(u, nu, v, nv)
71     link_list.append(link)
72     node_list[u].max_in  = max(node_list[u].max_in,  nu)
73     node_list[v].max_out = max(node_list[u].max_out, nv)
74
75 ######################################################################
76
77 def build_ag_graph_lists(u, node_labels, out, node_list, link_list):
78
79     if not u in node_list:
80         node = Node(len(node_list) + 1,
81                     (u in node_labels and node_labels[u]) or \
82                     re.search('<class \'(.*\.|)([a-zA-Z0-9_]*)\'>', str(type(u))).group(2))
83         node_list[u] = node
84
85         if isinstance(u, torch.autograd.Variable):
86             build_ag_graph_lists(u.grad_fn, node_labels, out, node_list, link_list)
87             add_link(node_list, link_list, u, 0, u.grad_fn, 0)
88         else:
89             if hasattr(u, 'next_functions'):
90                 i = 0
91                 for v, j in u.next_functions:
92                     build_ag_graph_lists(v, node_labels, out, node_list, link_list)
93                     add_link(node_list, link_list, u, i, v, j)
94                     i += 1
95
96 ######################################################################
97
98 def print_dot(node_list, link_list, out):
99     out.write('digraph{\n')
100
101     for n in node_list:
102         node = node_list[n]
103
104         out.write(
105             '  ' + \
106             str(node.id) + ' [shape=record,label="{ ' + \
107             slot_string(node.max_out, for_input = True) + \
108             node.label + \
109             slot_string(node.max_in, for_input = False) + \
110             ' }"]\n'
111         )
112
113     for n in link_list:
114         out.write('  ' + \
115                   slot(node_list, n.from_node, n.from_nb, for_input = False) + \
116                   ' -> ' + \
117                   slot(node_list, n.to_node, n.to_nb, for_input = True) + \
118                   '\n')
119
120     out.write('}\n')
121
122 ######################################################################
123
124 def save_dot(x, node_labels = {}, out = sys.stdout):
125     node_list = {}
126     link_list = []
127     build_ag_graph_lists(x, node_labels, out, node_list, link_list)
128     print_dot(node_list, link_list, out)
129
130 ######################################################################
131
132 # x = Variable(torch.rand(5))
133 # y = torch.topk(x, 3)
134 # l = torch.sqrt(torch.norm(y[0]) + torch.norm(5.0 * y[1].float()))
135
136 # save_dot(l, { l: 'variable l' }, open('/tmp/test.dot', 'w'))