Added a check that all nodes are connected to the inputs.
[dagnn.git] / dagnn.lua
index 0073e39..67e113d 100755 (executable)
--- a/dagnn.lua
+++ b/dagnn.lua
@@ -76,6 +76,10 @@ function DAG:putInOrder()
       end
    until nc == 0
 
+   for _, nnm in pairs(self.modules) do
+      assert(distance[nnm], 'Some modules are not connected to inputs')
+   end
+
    self.sorted = {}
    for m, d in pairs(distance) do
       table.insert(self.sorted, { distance = d, nnm = m })
@@ -86,14 +90,16 @@ function DAG:putInOrder()
    for i, a in ipairs(self.sorted) do self.sorted[i] = a.nnm end
 end
 
--- This accumulate x in a where they are both nested tables of
--- tensors. If first is true, set a = x.
+-- This accumulates x in a where they are both nested tables of
+-- tensors. If first is true, set a = x. Behavior is undefined if a
+-- and x do not have the exact same structure.
 function DAG:nestedAccTensor(a, x, first)
    if torch.type(x) == 'table' then
-      a = a or {}
+      local b = {}
       for i in pairs(x) do
-         a[i] = self:nestedAccTensor(a[i], x[i], first)
+         b[i] = self:nestedAccTensor(a[i], x[i], first)
       end
+      a = b
    else
       if first then
          if a then
@@ -141,7 +147,7 @@ function DAG:setInput(i)
    self:nestedApply(
       function(nnm)
          if #self.node[nnm].succ == 0 then
-            error('Input modules must have outgoing  edges.')
+            error('Input modules must have outgoing edges.')
          end
          if #self.node[nnm].pred > 0 then
             error('Input modules cannot have incoming edges.')
@@ -222,8 +228,9 @@ function DAG:updateOutput(input)
 
    self:nestedApply(
       function(nnm, i)
-         self.node[nnm].input = i
-         self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i)
+         local node = self.node[nnm]
+         node.input = i
+         self:rethrowErrors(nnm, node.index, 'updateOutput', i)
       end,
       self.inputModules,
       input
@@ -242,7 +249,7 @@ function DAG:updateOutput(input)
             end
          end
          node.input = i
-         self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i)
+         self:rethrowErrors(nnm, node.index, 'updateOutput', i)
       end
    end
 
@@ -261,7 +268,7 @@ function DAG:updateGradInput(input, gradOutput)
       function(nnm, go)
          local node = self.node[nnm]
          node.gradOutput = go
-         self:rethrowErrors(nnm, node.index, 'updateGradInput', self.node[nnm].input, go)
+         self:rethrowErrors(nnm, node.index, 'updateGradInput', node.input, go)
       end,
       self.outputModules, gradOutput
    )
@@ -282,7 +289,7 @@ function DAG:updateGradInput(input, gradOutput)
 
       if #node.gradInputSucc > 0 then
          self:updateGradOutput(node)
-         self:rethrowErrors(nnm, self.node[nnm].index, 'updateGradInput', node.input, node.gradOutput)
+         self:rethrowErrors(nnm, node.index, 'updateGradInput', node.input, node.gradOutput)
       end
 
       -- We fill the gradInputSucc of our predecessors
@@ -304,8 +311,6 @@ function DAG:updateGradInput(input, gradOutput)
 end
 
 function DAG:accGradParameters(input, gradOutput, scale)
-   scale = scale or 1
-
    assert(self.sorted, 'There has been a DAG structure change before a DAG:accGradParameters')
 
    self:nestedApply(