Update.
authorFrancois Fleuret <francois@fleuret.org>
Sat, 14 Jan 2017 16:04:06 +0000 (17:04 +0100)
committerFrancois Fleuret <francois@fleuret.org>
Sat, 14 Jan 2017 16:04:06 +0000 (17:04 +0100)
dagnn.lua

index 0073e39..5921c05 100755 (executable)
--- a/dagnn.lua
+++ b/dagnn.lua
@@ -86,14 +86,16 @@ function DAG:putInOrder()
    for i, a in ipairs(self.sorted) do self.sorted[i] = a.nnm end
 end
 
--- This accumulate x in a where they are both nested tables of
--- tensors. If first is true, set a = x.
+-- This accumulates x in a where they are both nested tables of
+-- tensors. If first is true, set a = x. Behavior is undefined if a
+-- and x do not have the exact same structure.
 function DAG:nestedAccTensor(a, x, first)
    if torch.type(x) == 'table' then
-      a = a or {}
+      local b = {}
       for i in pairs(x) do
-         a[i] = self:nestedAccTensor(a[i], x[i], first)
+         b[i] = self:nestedAccTensor(a[i], x[i], first)
       end
+      a = b
    else
       if first then
          if a then
@@ -222,8 +224,9 @@ function DAG:updateOutput(input)
 
    self:nestedApply(
       function(nnm, i)
-         self.node[nnm].input = i
-         self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i)
+         local node = self.node[nnm]
+         node.input = i
+         self:rethrowErrors(nnm, node.index, 'updateOutput', i)
       end,
       self.inputModules,
       input
@@ -242,7 +245,7 @@ function DAG:updateOutput(input)
             end
          end
          node.input = i
-         self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i)
+         self:rethrowErrors(nnm, node.index, 'updateOutput', i)
       end
    end
 
@@ -261,7 +264,7 @@ function DAG:updateGradInput(input, gradOutput)
       function(nnm, go)
          local node = self.node[nnm]
          node.gradOutput = go
-         self:rethrowErrors(nnm, node.index, 'updateGradInput', self.node[nnm].input, go)
+         self:rethrowErrors(nnm, node.index, 'updateGradInput', node.input, go)
       end,
       self.outputModules, gradOutput
    )
@@ -282,7 +285,7 @@ function DAG:updateGradInput(input, gradOutput)
 
       if #node.gradInputSucc > 0 then
          self:updateGradOutput(node)
-         self:rethrowErrors(nnm, self.node[nnm].index, 'updateGradInput', node.input, node.gradOutput)
+         self:rethrowErrors(nnm, node.index, 'updateGradInput', node.input, node.gradOutput)
       end
 
       -- We fill the gradInputSucc of our predecessors
@@ -304,8 +307,6 @@ function DAG:updateGradInput(input, gradOutput)
 end
 
 function DAG:accGradParameters(input, gradOutput, scale)
-   scale = scale or 1
-
    assert(self.sorted, 'There has been a DAG structure change before a DAG:accGradParameters')
 
    self:nestedApply(