DAG:addEdge now allows to add a bunch of edges in one shot. This allows to use anonym...
authorFrancois Fleuret <francois@fleuret.org>
Thu, 12 Jan 2017 20:18:11 +0000 (21:18 +0100)
committerFrancois Fleuret <francois@fleuret.org>
Thu, 12 Jan 2017 20:18:11 +0000 (21:18 +0100)
README.md
dagnn.lua
test-dagnn.lua

index 3b8d274..e18ea1f 100644 (file)
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@ This package implements a new module nn.DAG which inherits from nn.Container and
 
 #Example#
 
-The typical use is:
+A typical use would be:
 
 ```Lua
 model = nn.DAG()
@@ -11,34 +11,36 @@ model = nn.DAG()
 a = nn.Linear(100, 10)
 b = nn.ReLU()
 c = nn.Linear(10, 15)
-d = nn.Linear(10, 15)
-e = nn.CMulTable()
-f = nn.Linear(15, 15)
+d = nn.CMulTable()
+e = nn.Linear(15, 15)
 
 model:addEdge(a, b)
+model:addEdge(b, nn.Linear(10, 15), nn.ReLU(), d)
 model:addEdge(b, c)
-model:addEdge(b, d)
-model:addEdge(c, e)
-model:addEdge(d, e)
-model:addEdge(d, f)
+model:addEdge(c, d)
+model:addEdge(c, nn.Mul(-1), e)
 
 model:setInput(a)
-model:setOutput({ e, f })
+model:setOutput({ d, e })
 
-input = torch.Tensor(300, 100):uniform()
-output = model:updateOutput(input):clone()
+input = torch.Tensor(30, 100):uniform()
+output = model:updateOutput(input)
 
 ```
 
 which would encode the following graph
 
-                       +--> c ----> e -->
-                      /            /
-                     /            /
-    input --> a --> b ----> d ---+          output
-                             \
+                 +- Linear(10, 10) -> ReLU ---> d -->
+                /                              /
+               /                              /
+    --> a --> b -----------> c --------------+
                               \
-                               +--> f -->
+                               \
+                                +-- Mul(-1) --> e -->
+
+and run a forward pass with a random batch of 30 samples.
+
+Note that DAG:addEdge
 
 #Input and output#
 
index 7fc1018..9202932 100755 (executable)
--- a/dagnn.lua
+++ b/dagnn.lua
@@ -40,12 +40,19 @@ function DAG:createNode(nnm)
    end
 end
 
-function DAG:addEdge(nnma, nnmb)
+-- The main use should be to add an edge between two modules, but it
+-- can also add a full sequence of modules
+function DAG:addEdge(...)
    self.sorted = nil
-   self:createNode(nnma)
-   self:createNode(nnmb)
-   table.insert(self.node[nnmb].pred, nnma)
-   table.insert(self.node[nnma].succ, nnmb)
+   local prev
+   for _, nnm in pairs({...}) do
+      self:createNode(nnm)
+      if prev then
+         table.insert(self.node[nnm].pred, prev)
+         table.insert(self.node[prev].succ, nnm)
+      end
+      prev = nnm
+   end
 end
 
 -- Apply f on t recursively; use the corresponding element from args
index 1df04e2..8f92ccf 100755 (executable)
@@ -103,13 +103,13 @@ g = nn.CAddTable()
 model = nn.DAG()
 
 model:addEdge(a, b)
-model:addEdge(b, c)
+model:addEdge(b, nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 10), c)
 model:addEdge(b, d)
 model:addEdge(c, e)
 model:addEdge(d, e)
 model:addEdge(d, f)
 model:addEdge(e, g)
-model:addEdge(f, g)
+model:addEdge(f, nn.Mul(-1), g)
 
 model:setInput(a)
 model:setOutput(g)