The analytic gradient checks out.
[dagnn.git] / dagnn.lua
index 8a02cc6..0f93d95 100755 (executable)
--- a/dagnn.lua
+++ b/dagnn.lua
@@ -6,7 +6,7 @@ local DAG, parent = torch.class('nn.DAG', 'nn.Container')
 
 function DAG:__init()
    parent.__init(self)
-   -- Nodes are indexed by the module they encompass
+   -- Nodes are indexed by the module they contain
    self.node = { }
 end
 
@@ -27,26 +27,25 @@ function DAG:addEdge(nnma, nnmb)
    table.insert(self.node[nnma].succ, nnmb)
 end
 
--- Apply f on t recursively; use the corresponding a1 and a2 elements
--- (i.e. same keys) as second and third parameters to f when
--- available; return the results from f, organized in a similarly
--- nested table.
-function DAG:nestApply(f, t, a1, a2)
+-- Apply f on t recursively; use the corresponding element from args
+-- (i.e. same keys) as second parameter to f when available; return
+-- the results from f, organized in a similarly nested table.
+function DAG:nestedApply(f, t, args)
    if torch.type(t) == 'table' then
       local result = {}
       for k, s in pairs(t) do
-         result[k] = self:nestApply(f, s, a1 and a1[k], a2 and a2[k])
+         result[k] = self:nestedApply(f, s, args and args[k])
       end
       return result
    else
-      return f(t, a1, a2)
+      return f(t, args)
    end
 end
 
 function DAG:setInput(i)
    self.sorted = nil
    self.inputModules = i
-   self:nestApply(
+   self:nestedApply(
       function(nnm)
          if #self.node[nnm].succ == 0 then
             error('Input modules must have outgoing  edges.')
@@ -62,7 +61,7 @@ end
 function DAG:setOutput(o)
    self.sorted = nil
    self.outputModules = o
-   self:nestApply(
+   self:nestedApply(
       function(nnm)
          if #self.node[nnm].pred == 0 then
             error('Output module must have incoming edges.')
@@ -84,7 +83,7 @@ function DAG:putInOrder()
 
    local distance = {}
 
-   self:nestApply(function(m) distance[m] = 1 end, self.inputModules)
+   self:nestedApply(function(m) distance[m] = 1 end, self.inputModules)
 
    local nc
 
@@ -121,7 +120,7 @@ end
 function DAG:updateOutput(input)
    self:putInOrder()
 
-   self:nestApply(
+   self:nestedApply(
       function(nnm, i)
          self.node[nnm].input = i
          nnm:updateOutput(i)
@@ -147,7 +146,10 @@ function DAG:updateOutput(input)
       end
    end
 
-   self.output = self:nestApply(function(m) return m.output end, self.outputModules)
+   self.output = self:nestedApply(
+      function(m) return m.output end,
+      self.outputModules
+   )
 
    return self.output
 end
@@ -171,12 +173,12 @@ end
 function DAG:updateGradInput(input, gradOutput)
    self:putInOrder()
 
-   self:nestApply(
+   self:nestedApply(
       function(nnm, go) nnm:updateGradInput(self.node[nnm].input, go) end,
       self.outputModules, gradOutput
    )
 
-   self:nestApply(
+   self:nestedApply(
       function(nnm, i) self.node[nnm].input = i end,
       self.inputModules, input
    )
@@ -207,9 +209,31 @@ function DAG:updateGradInput(input, gradOutput)
       end
    end
 
-   self.gradInput = self:nestApply(function(m) return m.gradInput end, self.inputModules)
+   self.gradInput = self:nestedApply(function(m) return m.gradInput end, self.inputModules)
 
    return self.gradInput
 end
 
+function DAG:accGradParameters(input, gradOutput, scale)
+   scale = scale or 1
+
+   self:putInOrder()
+
+   self:nestedApply(
+      function(nnm, go) nnm:updateGradInput(self.node[nnm].input, go) end,
+      self.outputModules, gradOutput
+   )
+
+   self:nestedApply(
+      function(nnm, i) self.node[nnm].input = i end,
+      self.inputModules, input
+   )
+
+   for k = #self.sorted, 1, -1 do
+      local nnm = self.sorted[k]
+      local node = self.node[nnm]
+      nnm:accGradParameters(node.input, self:computeGradInput(node.gradInputSucc), scale)
+   end
+end
+
 return DAG