Now prints the input index when a node gets multiple input.
[dagnn.git] / dagnn.lua
index 14cd582..c17347d 100755 (executable)
--- a/dagnn.lua
+++ b/dagnn.lua
@@ -29,7 +29,7 @@ function DAG:__init()
    self.node = { }
 end
 
--- Apply f on t recursively; use the corresponding element from args
+-- Apply f on t recursively; use the corresponding elements from args
 -- (i.e. same keys) as second parameter to f when available; return
 -- the results from f, organized in a similarly nested table.
 function DAG:nestedApply(f, t, args)
@@ -167,20 +167,25 @@ function DAG:saveDot(filename)
 
    file:write('\n')
 
-   for nnma, node in pairs(self.node) do
+   for nnmb, node in pairs(self.node) do
       file:write(
          '  '
             .. node.index
-            .. ' [shape=box,label=\"' .. torch.type(nnma) .. '\"]'
+            .. ' [shape=box,label=\"' .. torch.type(nnmb) .. '\"]'
             .. '\n'
       )
 
-      for _, nnmb in pairs(node.succ) do
+      for i, nnma in pairs(node.pred) do
+         local decoration = ''
+         if #node.pred > 1 then
+            decoration = ' [label=\"' .. i .. '\"]'
+         end
          file:write(
             '  '
-               .. node.index
+               .. self.node[nnma].index
                .. ' -> '
                .. self.node[nnmb].index
+               .. decoration
                .. '\n'
          )
       end