Added DAG:setLabel to add a label to a module in the graph.
[dagnn.git] / dagnn.lua
index 1f45b2a..f9d6ff9 100755 (executable)
--- a/dagnn.lua
+++ b/dagnn.lua
@@ -148,6 +148,10 @@ function DAG:connect(...)
    end
 end
 
+function DAG:setLabel(nnm, label)
+   self.node[nnm].label = label
+end
+
 function DAG:setInput(i)
    self.sorted = nil
    self.inputModules = i
@@ -176,7 +180,11 @@ function DAG:print()
    self:putInOrder()
 
    for i, d in ipairs(self.sorted) do
-      print('#' .. i .. ' -> ' .. torch.type(d))
+      local decoration = ''
+      if self.node[d].label then
+         decoration = ' [' .. self.node[d].label .. ']'
+      end
+      print('#' .. i .. ' -> ' .. torch.type(d) .. decoration)
    end
 end
 
@@ -185,15 +193,33 @@ end
 function DAG:saveDot(filename)
    local file = (filename and io.open(filename, 'w')) or io.stdout
 
+   local function writeNestedCluster(prefix, list, indent)
+      local indent = indent or ''
+      if torch.type(list) == 'table' then
+         file:write(indent .. '  subgraph cluster_' .. prefix .. ' {\n');
+         for k, x in pairs(list) do
+            writeNestedCluster(prefix .. '_' .. k, x, '  ' .. indent)
+         end
+         file:write(indent .. '  }\n');
+      else
+         file:write(indent .. '  ' .. self.node[list].index .. ' [color=red]\n')
+      end
+   end
+
    file:write('digraph {\n')
 
    file:write('\n')
 
+   writeNestedCluster('input', self.inputModules)
+   writeNestedCluster('output', self.outputModules)
+
+   file:write('\n')
+
    for nnmb, node in pairs(self.node) do
       file:write(
          '  '
             .. node.index
-            .. ' [shape=box,label=\"' .. torch.type(nnmb) .. '\"]'
+            .. ' [shape=box,label=\"' .. (self.node[nnmb].label or torch.type(nnmb)) .. '\"]'
             .. '\n'
       )