- for _, d in ipairs(self.sorted) do
- if self.pred[d] then
- if #self.pred[d] == 1 then
- d:updateOutput(self.pred[d][1].output)
- elseif #self.pred[d] > 1 then
- local c = {}
- for k = 1, #self.pred[d] do
- c[k] = self.pred[d][k].output
+ self.output = self:nestApply(function(m) return m.output end, self.outputModules)
+
+ return self.output
+end
+
+function DAG:updateGradInput(input, gradOutput)
+ self:putInOrder()
+
+ self:nestApply(
+ function(nnm, go) nnm:updateGradInput(self.node[nnm].input, go) end,
+ self.outputModules, gradOutput
+ )
+
+ for _, node in pairs(self.node) do
+ node.gradInputSucc = {}
+ end
+
+ for k = #self.sorted, 1, -1 do
+ local nnm = self.sorted[k]
+ local node = self.node[nnm]
+ local pred, succ, gradInputSucc = node.pred, node.succ, node.gradInputSucc
+
+ if #gradInputSucc > 0 then
+ -- We update nnm:gradInput
+ local gi
+ if #gradInputSucc == 1 then
+ gi = gradInputSucc[1] -- we avoid a clone()
+ elseif #gradInputSucc > 1 then
+ for k = 1, #gradInputSucc do
+ if gi then
+ gi:add(gradInputSucc[k])
+ else
+ gi = gradInputSucc[k]:clone()
+ end