X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=dagnn.lua;h=c6d54ad5a057d72c4029ee6c84b87a6e03a30e7d;hb=0a630b54355382dfa68c0f3d51729bad0b4c58e6;hp=7fc1018f8dd7f3f88021c914608c5af1f5a710a5;hpb=063f198047f0202fa921aa09b772369b14ae8be2;p=dagnn.git diff --git a/dagnn.lua b/dagnn.lua index 7fc1018..c6d54ad 100755 --- a/dagnn.lua +++ b/dagnn.lua @@ -40,12 +40,19 @@ function DAG:createNode(nnm) end end -function DAG:addEdge(nnma, nnmb) +-- The main use should be to add an edge between two modules, but it +-- can also add a full sequence of modules +function DAG:addEdge(...) self.sorted = nil - self:createNode(nnma) - self:createNode(nnmb) - table.insert(self.node[nnmb].pred, nnma) - table.insert(self.node[nnma].succ, nnmb) + local prev + for _, nnm in pairs({...}) do + self:createNode(nnm) + if prev then + table.insert(self.node[nnm].pred, prev) + table.insert(self.node[prev].succ, nnm) + end + prev = nnm + end end -- Apply f on t recursively; use the corresponding element from args @@ -254,3 +261,35 @@ function DAG:accGradParameters(input, gradOutput, scale) end ---------------------------------------------------------------------- + +function DAG:dot(filename) + local file = (filename and io.open(filename, 'w')) or io.stdout + + file:write('digraph {\n') + + file:write('\n') + + for nnma, node in pairs(self.node) do + file:write( + ' ' + .. node.index + .. ' [shape=box,label=\"' .. torch.type(nnma) .. '\"]' + .. '\n' + ) + + for _, nnmb in pairs(node.succ) do + file:write( + ' ' + .. node.index + .. ' -> ' + .. self.node[nnmb].index + .. '\n' + ) + end + + file:write('\n') + end + + file:write('}\n') + +end