X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=dagnn.lua;h=c17347da8a18f6774dc975a68d5365491bdcd781;hb=0b9befca4f43c39e01801efc417069ef709b9271;hp=14cd5821c2064ac74b50dde101449d10d8f2274e;hpb=fe54a7c5c8425ee9783d82e16a42924e23add457;p=dagnn.git diff --git a/dagnn.lua b/dagnn.lua index 14cd582..c17347d 100755 --- a/dagnn.lua +++ b/dagnn.lua @@ -29,7 +29,7 @@ function DAG:__init() self.node = { } end --- Apply f on t recursively; use the corresponding element from args +-- Apply f on t recursively; use the corresponding elements from args -- (i.e. same keys) as second parameter to f when available; return -- the results from f, organized in a similarly nested table. function DAG:nestedApply(f, t, args) @@ -167,20 +167,25 @@ function DAG:saveDot(filename) file:write('\n') - for nnma, node in pairs(self.node) do + for nnmb, node in pairs(self.node) do file:write( ' ' .. node.index - .. ' [shape=box,label=\"' .. torch.type(nnma) .. '\"]' + .. ' [shape=box,label=\"' .. torch.type(nnmb) .. '\"]' .. '\n' ) - for _, nnmb in pairs(node.succ) do + for i, nnma in pairs(node.pred) do + local decoration = '' + if #node.pred > 1 then + decoration = ' [label=\"' .. i .. '\"]' + end file:write( ' ' - .. node.index + .. self.node[nnma].index .. ' -> ' .. self.node[nnmb].index + .. decoration .. '\n' ) end