From 56a476ee19396d0e7f186b238dc7d013000acb59 Mon Sep 17 00:00:00 2001
From: Francois Fleuret
Date: Fri, 13 Jan 2017 16:50:57 +0100
Subject: [PATCH] Remove the clone() for node.gradOutput when possible.

---
NOTE(review): this patch had been flattened onto a single line; it is
re-wrapped here with hunk line counts matching each @@ header. One code
fix was applied inside the first hunk: `node.gradOutput:resize(tensor)`
was replaced by `resizeAs(tensor)` -- Torch7's Tensor:resize() expects
sizes or a LongStorage, not a tensor, so the original added line would
raise at runtime. Indentation of context lines is reconstructed from the
file's 3-space convention -- confirm against blob de9d29b before `git am`.
(Text in this section, between the `---` marker and the diffstat, is
ignored by git-am.)

 dagnn.lua | 31 +++++++++++++------------------
 1 file changed, 13 insertions(+), 18 deletions(-)

diff --git a/dagnn.lua b/dagnn.lua
index 0c1d153..de9d29b 100755
--- a/dagnn.lua
+++ b/dagnn.lua
@@ -86,20 +86,20 @@ function DAG:putInOrder()
    for i, a in ipairs(self.sorted) do self.sorted[i] = a.nnm end
 end
 
-function DAG:computeGradOutput(gradInputSucc)
-   local gi
+function DAG:updateGradOutput(node)
+   local gradInputSucc = node.gradInputSucc
    if #gradInputSucc == 1 then
-      gi = gradInputSucc[1] -- we avoid a clone()
+      node.gradOutput = gradInputSucc[1]
    elseif #gradInputSucc > 1 then
-      for k = 1, #gradInputSucc do
-         if gi then
-            gi:add(gradInputSucc[k])
-         else
-            gi = gradInputSucc[k]:clone()
-         end
+      if node.gradOutput then
+         node.gradOutput:resizeAs(gradInputSucc[1]):copy(gradInputSucc[1])
+      else
+         node.gradOutput = gradInputSucc[1]:clone()
+      end
+      for k = 2, #gradInputSucc do
+         node.gradOutput:add(gradInputSucc[k])
       end
    end
-   return gi
 end
 
 ----------------------------------------------------------------------
@@ -206,7 +206,6 @@ function DAG:updateOutput(input)
    self:nestedApply(
       function(nnm, i)
          self.node[nnm].input = i
-         -- nnm:updateOutput(i)
          self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i)
       end,
       self.inputModules,
@@ -226,7 +225,6 @@ function DAG:updateOutput(input)
             end
          end
          node.input = i
-         -- nnm:updateOutput(i)
          self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i)
       end
    end
@@ -246,7 +244,6 @@ function DAG:updateGradInput(input, gradOutput)
       function(nnm, go)
          local node = self.node[nnm]
          node.gradOutput = go
-         -- nnm:updateGradInput(self.node[nnm].input, go)
         self:rethrowErrors(nnm, node.index, 'updateGradInput', self.node[nnm].input, go)
       end,
       self.outputModules, gradOutput
@@ -264,11 +261,10 @@
    for k = #self.sorted, 1, -1 do
       local nnm = self.sorted[k]
       local node = self.node[nnm]
-      local pred, gradInputSucc = node.pred, node.gradInputSucc
+      local pred = node.pred
 
-      if #gradInputSucc > 0 then
-         node.gradOutput = self:computeGradOutput(gradInputSucc)
-         -- nnm:updateGradInput(node.input, node.gradOutput)
+      if #node.gradInputSucc > 0 then
+         self:updateGradOutput(node)
          self:rethrowErrors(nnm, self.node[nnm].index, 'updateGradInput',
                             node.input, node.gradOutput)
       end
@@ -308,7 +304,6 @@ function DAG:accGradParameters(input, gradOutput, scale)
    for k = 1, #self.modules do
       local nnm = self.modules[k]
       local node = self.node[nnm]
-      -- nnm:accGradParameters(node.input, node.gradOutput, scale)
       self:rethrowErrors(nnm, k, 'accGradParameters', node.input, node.gradOutput, scale)
    end
 end
-- 
2.20.1