function DAG:__init()
parent.__init(self)
- -- Nodes are indexed by the module they encompass
+ -- Nodes are indexed by the module they contain
self.node = { }
end
table.insert(self.node[nnma].succ, nnmb)
end
--- Apply f on t recursively; use the corresponding a1 and a2 elements
--- (i.e. same keys) as second and third parameters to f when
--- available; return the results from f, organized in a similarly
--- nested table.
-function DAG:nestApply(f, t, a1, a2)
+-- Apply f on t recursively; use the corresponding element from args
+-- (i.e. same keys) as second parameter to f when available; return
+-- the results from f, organized in a similarly nested table.
+function DAG:nestedApply(f, t, args)
if torch.type(t) == 'table' then
local result = {}
for k, s in pairs(t) do
- result[k] = self:nestApply(f, s, a1 and a1[k], a2 and a2[k])
+ result[k] = self:nestedApply(f, s, args and args[k])
end
return result
else
- return f(t, a1, a2)
+ return f(t, args)
end
end
function DAG:setInput(i)
self.sorted = nil
self.inputModules = i
- self:nestApply(
+ self:nestedApply(
function(nnm)
if #self.node[nnm].succ == 0 then
error('Input modules must have outgoing edges.')
function DAG:setOutput(o)
self.sorted = nil
self.outputModules = o
- self:nestApply(
+ self:nestedApply(
function(nnm)
if #self.node[nnm].pred == 0 then
error('Output module must have incoming edges.')
local distance = {}
- self:nestApply(function(m) distance[m] = 1 end, self.inputModules)
+ self:nestedApply(function(m) distance[m] = 1 end, self.inputModules)
local nc
function DAG:updateOutput(input)
self:putInOrder()
- self:nestApply(
+ self:nestedApply(
function(nnm, i)
self.node[nnm].input = i
nnm:updateOutput(i)
end
end
- self.output = self:nestApply(function(m) return m.output end, self.outputModules)
+ self.output = self:nestedApply(
+ function(m) return m.output end,
+ self.outputModules
+ )
return self.output
end
function DAG:updateGradInput(input, gradOutput)
self:putInOrder()
- self:nestApply(
+ self:nestedApply(
function(nnm, go) nnm:updateGradInput(self.node[nnm].input, go) end,
self.outputModules, gradOutput
)
- self:nestApply(
+ self:nestedApply(
function(nnm, i) self.node[nnm].input = i end,
self.inputModules, input
)
end
end
- self.gradInput = self:nestApply(function(m) return m.gradInput end, self.inputModules)
+ self.gradInput = self:nestedApply(function(m) return m.gradInput end, self.inputModules)
return self.gradInput
end
+function DAG:accGradParameters(input, gradOutput, scale)
+ scale = scale or 1
+
+ self:putInOrder()
+
+ self:nestedApply(
+ function(nnm, go) nnm:updateGradInput(self.node[nnm].input, go) end,
+ self.outputModules, gradOutput
+ )
+
+ self:nestedApply(
+ function(nnm, i) self.node[nnm].input = i end,
+ self.inputModules, input
+ )
+
+ for k = #self.sorted, 1, -1 do
+ local nnm = self.sorted[k]
+ local node = self.node[nnm]
+ nnm:accGradParameters(node.input, self:computeGradInput(node.gradInputSucc), scale)
+ end
+end
+
return DAG