X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=dagnn.lua;h=67e113d84300cdefd43ab799d77d46a946e21e13;hb=7dc69291661dbdf731f0da2955e8fec5f288cbba;hp=0e7e8b0364ffb2b13a6319c9d3ee13c519892b17;hpb=d3b0a00b9f46d4ef147e8d52b9d02ebdf78ce9d3;p=dagnn.git diff --git a/dagnn.lua b/dagnn.lua index 0e7e8b0..67e113d 100755 --- a/dagnn.lua +++ b/dagnn.lua @@ -76,6 +76,10 @@ function DAG:putInOrder() end until nc == 0 + for _, nnm in pairs(self.modules) do + assert(distance[nnm], 'Some modules are not connected to inputs') + end + self.sorted = {} for m, d in pairs(distance) do table.insert(self.sorted, { distance = d, nnm = m }) @@ -86,14 +90,16 @@ function DAG:putInOrder() for i, a in ipairs(self.sorted) do self.sorted[i] = a.nnm end end --- This accumulate x in a where they are both nested tables of --- tensors. If first is true, set a = x. +-- This accumulates x in a where they are both nested tables of +-- tensors. If first is true, set a = x. Behavior is undefined if a +-- and x do not have the exact same structure. function DAG:nestedAccTensor(a, x, first) if torch.type(x) == 'table' then - a = a or {} + local b = {} for i in pairs(x) do - a[i] = self:nestedAccTensor(a[i], x[i], first) + b[i] = self:nestedAccTensor(a[i], x[i], first) end + a = b else if first then if a then @@ -141,10 +147,10 @@ function DAG:setInput(i) self:nestedApply( function(nnm) if #self.node[nnm].succ == 0 then - error('Input modules must have outgoing edges.') + error('Input modules must have outgoing edges.') end if #self.node[nnm].pred > 0 then - error('Input modules cannog have incoming edges.') + error('Input modules cannot have incoming edges.') end end, self.inputModules @@ -222,8 +228,9 @@ function DAG:updateOutput(input) self:nestedApply( function(nnm, i) - self.node[nnm].input = i - self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i) + local node = self.node[nnm] + node.input = i + self:rethrowErrors(nnm, node.index, 'updateOutput', i) end, self.inputModules, input @@ -242,7 +249,7 @@ function DAG:updateOutput(input) end end node.input = i - self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i) + self:rethrowErrors(nnm, node.index, 'updateOutput', i) end end @@ -261,7 +268,7 @@ function DAG:updateGradInput(input, gradOutput) function(nnm, go) local node = self.node[nnm] node.gradOutput = go - self:rethrowErrors(nnm, node.index, 'updateGradInput', self.node[nnm].input, go) + self:rethrowErrors(nnm, node.index, 'updateGradInput', node.input, go) end, self.outputModules, gradOutput ) @@ -282,7 +289,7 @@ function DAG:updateGradInput(input, gradOutput) if #node.gradInputSucc > 0 then self:updateGradOutput(node) - self:rethrowErrors(nnm, self.node[nnm].index, 'updateGradInput', node.input, node.gradOutput) + self:rethrowErrors(nnm, node.index, 'updateGradInput', node.input, node.gradOutput) end -- We fill the gradInputSucc of our predecessors @@ -304,8 +311,6 @@ function DAG:updateGradInput(input, gradOutput) end function DAG:accGradParameters(input, gradOutput, scale) - scale = scale or 1 - assert(self.sorted, 'There has been a DAG structure change before a DAG:accGradParameters') self:nestedApply(