From be353fdfc2a57172064a024f8cec6015c9d908e5 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Thu, 12 Jan 2017 21:18:11 +0100 Subject: [PATCH] DAG:addEdge now allows to add a bunch of edges in one shot. This allows to use anonymous modules when they are used only once. --- README.md | 36 +++++++++++++++++++----------------- dagnn.lua | 17 ++++++++++++----- test-dagnn.lua | 4 ++-- 3 files changed, 33 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index 3b8d274..e18ea1f 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ This package implements a new module nn.DAG which inherits from nn.Container and #Example# -The typical use is: +A typical use would be: ```Lua model = nn.DAG() @@ -11,34 +11,36 @@ model = nn.DAG() a = nn.Linear(100, 10) b = nn.ReLU() c = nn.Linear(10, 15) -d = nn.Linear(10, 15) -e = nn.CMulTable() -f = nn.Linear(15, 15) +d = nn.CMulTable() +e = nn.Linear(15, 15) model:addEdge(a, b) +model:addEdge(b, nn.Linear(10, 15), nn.ReLU(), d) model:addEdge(b, c) -model:addEdge(b, d) -model:addEdge(c, e) -model:addEdge(d, e) -model:addEdge(d, f) +model:addEdge(c, d) +model:addEdge(c, nn.Mul(-1), e) model:setInput(a) -model:setOutput({ e, f }) +model:setOutput({ d, e }) -input = torch.Tensor(300, 100):uniform() -output = model:updateOutput(input):clone() +input = torch.Tensor(30, 100):uniform() +output = model:updateOutput(input) ``` which would encode the following graph - +--> c ----> e --> - / / - / / - input --> a --> b ----> d ---+ output - \ + +- Linear(10, 10) -> ReLU ---> d --> + / / + / / + --> a --> b -----------> c --------------+ \ - +--> f --> + \ + +-- Mul(-1) --> e --> + +and run a forward pass with a random batch of 30 samples. + +Note that DAG:addEdge #Input and output# diff --git a/dagnn.lua b/dagnn.lua index 7fc1018..9202932 100755 --- a/dagnn.lua +++ b/dagnn.lua @@ -40,12 +40,19 @@ function DAG:createNode(nnm) end end -function DAG:addEdge(nnma, nnmb) +-- The main use should be to add an edge between two modules, but it +-- can also add a full sequence of modules +function DAG:addEdge(...) self.sorted = nil - self:createNode(nnma) - self:createNode(nnmb) - table.insert(self.node[nnmb].pred, nnma) - table.insert(self.node[nnma].succ, nnmb) + local prev + for _, nnm in pairs({...}) do + self:createNode(nnm) + if prev then + table.insert(self.node[nnm].pred, prev) + table.insert(self.node[prev].succ, nnm) + end + prev = nnm + end end -- Apply f on t recursively; use the corresponding element from args diff --git a/test-dagnn.lua b/test-dagnn.lua index 1df04e2..8f92ccf 100755 --- a/test-dagnn.lua +++ b/test-dagnn.lua @@ -103,13 +103,13 @@ g = nn.CAddTable() model = nn.DAG() model:addEdge(a, b) -model:addEdge(b, c) +model:addEdge(b, nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 10), c) model:addEdge(b, d) model:addEdge(c, e) model:addEdge(d, e) model:addEdge(d, f) model:addEdge(e, g) -model:addEdge(f, g) +model:addEdge(f, nn.Mul(-1), g) model:setInput(a) model:setOutput(g) -- 2.20.1