+   assert(self.sorted, 'There has been a structure change before a call to DAG:updateGradInput.')
+
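+   -- Seed the output modules: store the gradOutput provided by the caller
+   -- and back-propagate through each output module with its stored input.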
+   self:nestedApply(
+      function(nnm, go)
+         local node = self.node[nnm]
+         node.gradOutput = go
+         self:rethrowErrors(nnm, node.index, 'updateGradInput', node.input, go)
+      end,
+      self.outputModules, gradOutput
+   )
+
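+   -- Record the current input on each of the DAG's input modules.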
+   self:nestedApply(
+      function(nnm, i) self.node[nnm].input = i end,
+      self.inputModules, input
+   )
+
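+   -- Reset the accumulators that collect the gradients coming from each
+   -- node's successors.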
+   for _, node in pairs(self.node) do
+      node.gradInputSucc = {}
+   end
+
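+   -- Walk the nodes in reverse topological order, so that every module is
+   -- visited after all of its successors have computed their gradInput.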
+   for k = #self.sorted, 1, -1 do
+      local nnm = self.sorted[k]
+      local node = self.node[nnm]
+      local pred = node.pred
+
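+      -- Aggregate the gradients received from the successors into
+      -- node.gradOutput, then back-propagate through this module.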
+      if #node.gradInputSucc > 0 then
+         self:updateGradOutput(node)
+         self:rethrowErrors(nnm, node.index, 'updateGradInput', node.input, node.gradOutput)
+      end
+
+      -- We fill the gradInputSucc of our predecessors
+      if #pred == 1 then
+         table.insert(self.node[pred[1]].gradInputSucc, nnm.gradInput)
+      elseif #pred > 1 then
+         assert(torch.type(nnm.gradInput) == 'table',
+                'Should have a table gradInput since it has multiple predecessors.')
+         for n = 1, #pred do
+            table.insert(self.node[pred[n]].gradInputSucc, nnm.gradInput[n])
+         end
+      end
+   end
+
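+   -- Collect the gradInput of the input modules, mirroring the nested
+   -- structure of the DAG's input.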
+   self.gradInput = self:nestedApply(
+      function(m) return m.gradInput end,
+      self.inputModules
+   )
+
+   return self.gradInput
+end
+
+function DAG:accGradParameters(input, gradOutput, scale)
+   assert(self.sorted, 'There has been a structure change before a call to DAG:accGradParameters.')
+
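+   -- Store the gradOutput provided by the caller on the output modules.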
+   self:nestedApply(
+      function(nnm, go) self.node[nnm].gradOutput = go end,
+      self.outputModules, gradOutput
+   )
+
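+   -- Store the provided input on the input modules.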
+   self:nestedApply(
+      function(nnm, i) self.node[nnm].input = i end,
+      self.inputModules, input
+   )
+
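+   -- The per-node input and gradOutput were set during updateOutput and
+   -- updateGradInput, so the modules can be visited in any order here.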
+   for k = 1, #self.modules do
+      local nnm = self.modules[k]
+      local node = self.node[nnm]
+      self:rethrowErrors(nnm, k, 'accGradParameters', node.input, node.gradOutput, scale)
+   end