Updated the headers.
[dagnn.git] / dagnn.lua
index 8a02cc6..158ef78 100755 (executable)
--- a/dagnn.lua
+++ b/dagnn.lua
@@ -1,4 +1,23 @@
 
+--[[
+
+   Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/
+   Written by Francois Fleuret <francois.fleuret@idiap.ch>
+
+   This file is free software: you can redistribute it and/or modify
+   it under the terms of the GNU General Public License version 3 as
+   published by the Free Software Foundation.
+
+   It is distributed in the hope that it will be useful, but WITHOUT
+   ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
+   or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public
+   License for more details.
+
+   You should have received a copy of the GNU General Public License
+   along with this file.  If not, see <http://www.gnu.org/licenses/>.
+
+]]--
+
 require 'torch'
 require 'nn'
 
@@ -6,7 +25,7 @@ local DAG, parent = torch.class('nn.DAG', 'nn.Container')
 
 function DAG:__init()
    parent.__init(self)
-   -- Nodes are indexed by the module they encompass
+   -- Nodes are indexed by the module they contain
    self.node = { }
 end
 
@@ -27,26 +46,25 @@ function DAG:addEdge(nnma, nnmb)
    table.insert(self.node[nnma].succ, nnmb)
 end
 
--- Apply f on t recursively; use the corresponding a1 and a2 elements
--- (i.e. same keys) as second and third parameters to f when
--- available; return the results from f, organized in a similarly
--- nested table.
-function DAG:nestApply(f, t, a1, a2)
+-- Apply f on t recursively; use the corresponding element from args
+-- (i.e. same keys) as second parameter to f when available; return
+-- the results from f, organized in a similarly nested table.
+function DAG:nestedApply(f, t, args)
    if torch.type(t) == 'table' then
       local result = {}
       for k, s in pairs(t) do
-         result[k] = self:nestApply(f, s, a1 and a1[k], a2 and a2[k])
+         result[k] = self:nestedApply(f, s, args and args[k])
       end
       return result
    else
-      return f(t, a1, a2)
+      return f(t, args)
    end
 end
 
 function DAG:setInput(i)
    self.sorted = nil
    self.inputModules = i
-   self:nestApply(
+   self:nestedApply(
       function(nnm)
          if #self.node[nnm].succ == 0 then
             error('Input modules must have outgoing  edges.')
@@ -62,7 +80,7 @@ end
 function DAG:setOutput(o)
    self.sorted = nil
    self.outputModules = o
-   self:nestApply(
+   self:nestedApply(
       function(nnm)
          if #self.node[nnm].pred == 0 then
             error('Output module must have incoming edges.')
@@ -84,7 +102,7 @@ function DAG:putInOrder()
 
    local distance = {}
 
-   self:nestApply(function(m) distance[m] = 1 end, self.inputModules)
+   self:nestedApply(function(m) distance[m] = 1 end, self.inputModules)
 
    local nc
 
@@ -121,7 +139,7 @@ end
 function DAG:updateOutput(input)
    self:putInOrder()
 
-   self:nestApply(
+   self:nestedApply(
       function(nnm, i)
          self.node[nnm].input = i
          nnm:updateOutput(i)
@@ -147,7 +165,10 @@ function DAG:updateOutput(input)
       end
    end
 
-   self.output = self:nestApply(function(m) return m.output end, self.outputModules)
+   self.output = self:nestedApply(
+      function(m) return m.output end,
+      self.outputModules
+   )
 
    return self.output
 end
@@ -171,12 +192,12 @@ end
 function DAG:updateGradInput(input, gradOutput)
    self:putInOrder()
 
-   self:nestApply(
+   self:nestedApply(
       function(nnm, go) nnm:updateGradInput(self.node[nnm].input, go) end,
       self.outputModules, gradOutput
    )
 
-   self:nestApply(
+   self:nestedApply(
       function(nnm, i) self.node[nnm].input = i end,
       self.inputModules, input
    )
@@ -207,9 +228,29 @@ function DAG:updateGradInput(input, gradOutput)
       end
    end
 
-   self.gradInput = self:nestApply(function(m) return m.gradInput end, self.inputModules)
+   self.gradInput = self:nestedApply(function(m) return m.gradInput end, self.inputModules)
 
    return self.gradInput
 end
 
-return DAG
+function DAG:accGradParameters(input, gradOutput, scale)
+   scale = scale or 1
+
+   self:putInOrder()
+
+   self:nestedApply(
+      function(nnm, go) nnm:updateGradInput(self.node[nnm].input, go) end,
+      self.outputModules, gradOutput
+   )
+
+   self:nestedApply(
+      function(nnm, i) self.node[nnm].input = i end,
+      self.inputModules, input
+   )
+
+   for k = #self.sorted, 1, -1 do
+      local nnm = self.sorted[k]
+      local node = self.node[nnm]
+      nnm:accGradParameters(node.input, self:computeGradInput(node.gradInputSucc), scale)
+   end
+end