Update.
[agtree2dot.git] / agtree2dot.py
index 2d89af5..2dcc1d7 100755 (executable)
@@ -1,15 +1,22 @@
-#!/usr/bin/env python-for-pytorch
+#########################################################################
+# This program is free software: you can redistribute it and/or modify  #
+# it under the terms of the version 3 of the GNU General Public License #
+# as published by the Free Software Foundation.                         #
+#                                                                       #
+# This program is distributed in the hope that it will be useful, but   #
+# WITHOUT ANY WARRANTY; without even the implied warranty of            #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      #
+# General Public License for more details.                              #
+#                                                                       #
+# You should have received a copy of the GNU General Public License     #
+# along with this program. If not, see <http://www.gnu.org/licenses/>.  #
+#                                                                       #
+# Written by and Copyright (C) Francois Fleuret                         #
+# Contact <francois.fleuret@idiap.ch> for comments & bug reports        #
+#########################################################################
 
 import torch
-import math, sys, re
-
-from torch import nn
-from torch.nn import functional as fn
-
-from torch import Tensor
-from torch.autograd import Variable
-from torch.nn.parameter import Parameter
-from torch.nn import Module
+import sys, re
 
 ######################################################################
 
@@ -50,7 +57,7 @@ def slot_string(k, for_input):
     if k > 0:
         if not for_input: result = ' |' + result
         result +=  ' { <' + label + '0> 0'
-        for j in range(1, k+1):
+        for j in range(1, k + 1):
             result += " | " + '<' + label + str(j) + '> ' + str(j)
         result += " } "
         if for_input: result = result + '| '
@@ -60,31 +67,34 @@ def slot_string(k, for_input):
 ######################################################################
 
 def add_link(node_list, link_list, u, nu, v, nv):
-    link = Link(u, nu, v, nv)
-    link_list.append(link)
-    node_list[u].max_in  = max(node_list[u].max_in,  nu)
-    node_list[v].max_out = max(node_list[u].max_out, nv)
+    if u is not None and v is not None:
+        link = Link(u, nu, v, nv)
+        link_list.append(link)
+        node_list[u].max_in  = max(node_list[u].max_in,  nu)
+        node_list[v].max_out = max(node_list[v].max_out, nv)
 
 ######################################################################
 
-def build_ag_graph_lists(u, node_labels, out, node_list, link_list):
+def fill_graph_lists(u, node_labels, node_list, link_list):
 
-    if not u in node_list:
+    if u is not None and not u in node_list:
         node = Node(len(node_list) + 1,
                     (u in node_labels and node_labels[u]) or \
                     re.search('<class \'(.*\.|)([a-zA-Z0-9_]*)\'>', str(type(u))).group(2))
         node_list[u] = node
 
-        if isinstance(u, torch.autograd.Variable):
-            build_ag_graph_lists(u.grad_fn, node_labels, out, node_list, link_list)
+        if hasattr(u, 'grad_fn'):
+            fill_graph_lists(u.grad_fn, node_labels, node_list, link_list)
             add_link(node_list, link_list, u, 0, u.grad_fn, 0)
-        else:
-            if hasattr(u, 'next_functions'):
-                i = 0
-                for v, j in u.next_functions:
-                    build_ag_graph_lists(v, node_labels, out, node_list, link_list)
-                    add_link(node_list, link_list, u, i, v, j)
-                    i += 1
+
+        if hasattr(u, 'variable'):
+            fill_graph_lists(u.variable, node_labels, node_list, link_list)
+            add_link(node_list, link_list, u, 0, u.variable, 0)
+
+        if hasattr(u, 'next_functions'):
+            for i, (v, j) in enumerate(u.next_functions):
+                fill_graph_lists(v, node_labels, node_list, link_list)
+                add_link(node_list, link_list, u, i, v, j)
 
 ######################################################################
 
@@ -94,14 +104,22 @@ def print_dot(node_list, link_list, out):
     for n in node_list:
         node = node_list[n]
 
-        out.write(
-            '  ' + \
-            str(node.id) + ' [shape=record,label="{ ' + \
-            slot_string(node.max_out, for_input = True) + \
-            node.label + \
-            slot_string(node.max_in, for_input = False) + \
-            ' }"]\n'
-        )
+        if isinstance(n, torch.autograd.Variable):
+            out.write(
+                '  ' + \
+                str(node.id) + ' [shape=note,style=filled, fillcolor="#e0e0ff",label="' + \
+                node.label + ' ' + re.search('torch\.Size\((.*)\)', str(n.data.size())).group(1) + \
+                '"]\n'
+            )
+        else:
+            out.write(
+                '  ' + \
+                str(node.id) + ' [shape=record,style=filled, fillcolor="#f0f0f0",label="{ ' + \
+                slot_string(node.max_out, for_input = True) + \
+                node.label + \
+                slot_string(node.max_in, for_input = False) + \
+                ' }"]\n'
+            )
 
     for n in link_list:
         out.write('  ' + \
@@ -115,15 +133,8 @@ def print_dot(node_list, link_list, out):
 ######################################################################
 
 def save_dot(x, node_labels = {}, out = sys.stdout):
-    node_list = {}
-    link_list = []
-    build_ag_graph_lists(x, node_labels, out, node_list, link_list)
+    node_list, link_list = {}, []
+    fill_graph_lists(x, node_labels, node_list, link_list)
     print_dot(node_list, link_list, out)
 
 ######################################################################
-
-# x = Variable(torch.rand(5))
-# y = torch.topk(x, 3)
-# l = torch.sqrt(torch.norm(y[0]) + torch.norm(5.0 * y[1].float()))
-
-# save_dot(l, { l: 'variable l' }, open('/tmp/test.dot', 'w'))