2d89af5cc8130d90e9e8087e549b45f66b69aefd
[agtree2dot.git] / agtree2dot.py
1 #!/usr/bin/env python-for-pytorch
2
3 import torch
4 import math, sys, re
5
6 from torch import nn
7 from torch.nn import functional as fn
8
9 from torch import Tensor
10 from torch.autograd import Variable
11 from torch.nn.parameter import Parameter
12 from torch.nn import Module
13
14 ######################################################################
15
16 class Link:
17     def __init__(self, from_node, from_nb, to_node, to_nb):
18         self.from_node = from_node
19         self.from_nb = from_nb
20         self.to_node = to_node
21         self.to_nb = to_nb
22
23 class Node:
24     def __init__(self, id, label):
25         self.id = id
26         self.label = label
27         self.max_in = -1
28         self.max_out = -1
29
30 def slot(node_list, n, k, for_input):
31     if for_input:
32         if node_list[n].max_out > 0:
33             return str(node_list[n].id) + ':input' + str(k)
34         else:
35             return str(node_list[n].id)
36     else:
37         if node_list[n].max_in > 0:
38             return str(node_list[n].id) + ':output' + str(k)
39         else:
40             return str(node_list[n].id)
41
42 def slot_string(k, for_input):
43     result = ''
44
45     if for_input:
46         label = 'input'
47     else:
48         label = 'output'
49
50     if k > 0:
51         if not for_input: result = ' |' + result
52         result +=  ' { <' + label + '0> 0'
53         for j in range(1, k+1):
54             result += " | " + '<' + label + str(j) + '> ' + str(j)
55         result += " } "
56         if for_input: result = result + '| '
57
58     return result
59
60 ######################################################################
61
62 def add_link(node_list, link_list, u, nu, v, nv):
63     link = Link(u, nu, v, nv)
64     link_list.append(link)
65     node_list[u].max_in  = max(node_list[u].max_in,  nu)
66     node_list[v].max_out = max(node_list[u].max_out, nv)
67
68 ######################################################################
69
70 def build_ag_graph_lists(u, node_labels, out, node_list, link_list):
71
72     if not u in node_list:
73         node = Node(len(node_list) + 1,
74                     (u in node_labels and node_labels[u]) or \
75                     re.search('<class \'(.*\.|)([a-zA-Z0-9_]*)\'>', str(type(u))).group(2))
76         node_list[u] = node
77
78         if isinstance(u, torch.autograd.Variable):
79             build_ag_graph_lists(u.grad_fn, node_labels, out, node_list, link_list)
80             add_link(node_list, link_list, u, 0, u.grad_fn, 0)
81         else:
82             if hasattr(u, 'next_functions'):
83                 i = 0
84                 for v, j in u.next_functions:
85                     build_ag_graph_lists(v, node_labels, out, node_list, link_list)
86                     add_link(node_list, link_list, u, i, v, j)
87                     i += 1
88
89 ######################################################################
90
91 def print_dot(node_list, link_list, out):
92     out.write('digraph{\n')
93
94     for n in node_list:
95         node = node_list[n]
96
97         out.write(
98             '  ' + \
99             str(node.id) + ' [shape=record,label="{ ' + \
100             slot_string(node.max_out, for_input = True) + \
101             node.label + \
102             slot_string(node.max_in, for_input = False) + \
103             ' }"]\n'
104         )
105
106     for n in link_list:
107         out.write('  ' + \
108                   slot(node_list, n.from_node, n.from_nb, for_input = False) + \
109                   ' -> ' + \
110                   slot(node_list, n.to_node, n.to_nb, for_input = True) + \
111                   '\n')
112
113     out.write('}\n')
114
115 ######################################################################
116
117 def save_dot(x, node_labels = {}, out = sys.stdout):
118     node_list = {}
119     link_list = []
120     build_ag_graph_lists(x, node_labels, out, node_list, link_list)
121     print_dot(node_list, link_list, out)
122
123 ######################################################################
124
125 # x = Variable(torch.rand(5))
126 # y = torch.topk(x, 3)
127 # l = torch.sqrt(torch.norm(y[0]) + torch.norm(5.0 * y[1].float()))
128
129 # save_dot(l, { l: 'variable l' }, open('/tmp/test.dot', 'w'))