From 0a630b54355382dfa68c0f3d51729bad0b4c58e6 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Thu, 12 Jan 2017 22:35:26 +0100 Subject: [PATCH] Added DAG:dot() to generate a dot file for visualization. --- dagnn.lua | 32 ++++++++++++++++++++++++++++++++ test-dagnn.lua | 2 ++ 2 files changed, 34 insertions(+) diff --git a/dagnn.lua b/dagnn.lua index 9202932..c6d54ad 100755 --- a/dagnn.lua +++ b/dagnn.lua @@ -261,3 +261,35 @@ function DAG:accGradParameters(input, gradOutput, scale) end ---------------------------------------------------------------------- + +function DAG:dot(filename) + local file = (filename and io.open(filename, 'w')) or io.stdout + + file:write('digraph {\n') + + file:write('\n') + + for nnma, node in pairs(self.node) do + file:write( + ' ' + .. node.index + .. ' [shape=box,label=\"' .. torch.type(nnma) .. '\"]' + .. '\n' + ) + + for _, nnmb in pairs(node.succ) do + file:write( + ' ' + .. node.index + .. ' -> ' + .. self.node[nnmb].index + .. '\n' + ) + end + + file:write('\n') + end + + file:write('}\n') + +end diff --git a/test-dagnn.lua b/test-dagnn.lua index 3dea310..53302fd 100755 --- a/test-dagnn.lua +++ b/test-dagnn.lua @@ -108,3 +108,5 @@ local output = model:updateOutput(input):clone() output:uniform() print('Error = ' .. checkGrad(model, nn.MSECriterion(), input, output)) + +model:dot('/tmp/graph.dot') -- 2.20.1