X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=dagnn.git;a=blobdiff_plain;f=dagnn.lua;h=9202932da2950905fc49734c8762226e0e4d6f28;hp=7fc1018f8dd7f3f88021c914608c5af1f5a710a5;hb=be353fdfc2a57172064a024f8cec6015c9d908e5;hpb=116fbcd681f9e097f7acd89f61a15c6b7bd113ce diff --git a/dagnn.lua b/dagnn.lua index 7fc1018..9202932 100755 --- a/dagnn.lua +++ b/dagnn.lua @@ -40,12 +40,19 @@ function DAG:createNode(nnm) end end -function DAG:addEdge(nnma, nnmb) +-- The main use should be to add an edge between two modules, but it +-- can also add a full sequence of modules +function DAG:addEdge(...) self.sorted = nil - self:createNode(nnma) - self:createNode(nnmb) - table.insert(self.node[nnmb].pred, nnma) - table.insert(self.node[nnma].succ, nnmb) + local prev + for _, nnm in pairs({...}) do + self:createNode(nnm) + if prev then + table.insert(self.node[nnm].pred, prev) + table.insert(self.node[prev].succ, nnm) + end + prev = nnm + end end -- Apply f on t recursively; use the corresponding element from args