Remove the clone() for node.gradOutput when possible.
authorFrancois Fleuret <francois@fleuret.org>
Fri, 13 Jan 2017 15:50:57 +0000 (16:50 +0100)
committerFrancois Fleuret <francois@fleuret.org>
Fri, 13 Jan 2017 15:50:57 +0000 (16:50 +0100)
dagnn.lua

index 0c1d153..de9d29b 100755 (executable)
--- a/dagnn.lua
+++ b/dagnn.lua
@@ -86,20 +86,20 @@ function DAG:putInOrder()
    for i, a in ipairs(self.sorted) do self.sorted[i] = a.nnm end
 end
 
-function DAG:computeGradOutput(gradInputSucc)
-   local gi
+function DAG:updateGradOutput(node)
+   local gradInputSucc = node.gradInputSucc
    if #gradInputSucc == 1 then
-      gi = gradInputSucc[1] -- we avoid a clone()
+      node.gradOutput = gradInputSucc[1]
    elseif #gradInputSucc > 1 then
-      for k = 1, #gradInputSucc do
-         if gi then
-            gi:add(gradInputSucc[k])
-         else
-            gi = gradInputSucc[k]:clone()
-         end
+      if node.gradOutput then
+         node.gradOutput:resize(gradInputSucc[1]):copy(gradInputSucc[1])
+      else
+         node.gradOutput = gradInputSucc[1]:clone()
+      end
+      for k = 2, #gradInputSucc do
+         node.gradOutput:add(gradInputSucc[k])
       end
    end
-   return gi
 end
 
 ----------------------------------------------------------------------
@@ -206,7 +206,6 @@ function DAG:updateOutput(input)
    self:nestedApply(
       function(nnm, i)
          self.node[nnm].input = i
-         -- nnm:updateOutput(i)
          self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i)
       end,
       self.inputModules,
@@ -226,7 +225,6 @@ function DAG:updateOutput(input)
             end
          end
          node.input = i
-         -- nnm:updateOutput(i)
          self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i)
       end
    end
@@ -246,7 +244,6 @@ function DAG:updateGradInput(input, gradOutput)
       function(nnm, go)
          local node = self.node[nnm]
          node.gradOutput = go
-         -- nnm:updateGradInput(self.node[nnm].input, go)
          self:rethrowErrors(nnm, node.index, 'updateGradInput', self.node[nnm].input, go)
       end,
       self.outputModules, gradOutput
@@ -264,11 +261,10 @@ function DAG:updateGradInput(input, gradOutput)
    for k = #self.sorted, 1, -1 do
       local nnm = self.sorted[k]
       local node = self.node[nnm]
-      local pred, gradInputSucc = node.pred, node.gradInputSucc
+      local pred = node.pred
 
-      if #gradInputSucc > 0 then
-         node.gradOutput = self:computeGradOutput(gradInputSucc)
-         -- nnm:updateGradInput(node.input, node.gradOutput)
+      if #node.gradInputSucc > 0 then
+         self:updateGradOutput(node)
          self:rethrowErrors(nnm, self.node[nnm].index, 'updateGradInput', node.input, node.gradOutput)
       end
 
@@ -308,7 +304,6 @@ function DAG:accGradParameters(input, gradOutput, scale)
    for k = 1, #self.modules do
       local nnm = self.modules[k]
       local node = self.node[nnm]
-      -- nnm:accGradParameters(node.input, node.gradOutput, scale)
       self:rethrowErrors(nnm, k, 'accGradParameters', node.input, node.gradOutput, scale)
    end
 end