Renamed DAG:addEdge to DAG:connect
[dagnn.git] / dagnn.lua
index 7fc1018..9203264 100755 (executable)
--- a/dagnn.lua
+++ b/dagnn.lua
@@ -40,12 +40,19 @@ function DAG:createNode(nnm)
    end
 end
 
-function DAG:addEdge(nnma, nnmb)
+-- The main use should be to add an edge between two modules, but it
+-- can also add a full sequence of modules
+function DAG:connect(...)
    self.sorted = nil
-   self:createNode(nnma)
-   self:createNode(nnmb)
-   table.insert(self.node[nnmb].pred, nnma)
-   table.insert(self.node[nnma].succ, nnmb)
+   local prev
+   for _, nnm in pairs({...}) do
+      self:createNode(nnm)
+      if prev then
+         table.insert(self.node[nnm].pred, prev)
+         table.insert(self.node[prev].succ, nnm)
+      end
+      prev = nnm
+   end
 end
 
 -- Apply f on t recursively; use the corresponding element from args
@@ -254,3 +261,35 @@ function DAG:accGradParameters(input, gradOutput, scale)
 end
 
 ----------------------------------------------------------------------
+
+function DAG:dot(filename)
+   local file = (filename and io.open(filename, 'w')) or io.stdout
+
+   file:write('digraph {\n')
+
+   file:write('\n')
+
+   for nnma, node in pairs(self.node) do
+      file:write(
+         '  '
+            .. node.index
+            .. ' [shape=box,label=\"' .. torch.type(nnma) .. '\"]'
+            .. '\n'
+      )
+
+      for _, nnmb in pairs(node.succ) do
+         file:write(
+            '  '
+               .. node.index
+               .. ' -> '
+               .. self.node[nnmb].index
+               .. '\n'
+         )
+      end
+
+      file:write('\n')
+   end
+
+   file:write('}\n')
+
+end