X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=dagnn.git;a=blobdiff_plain;f=dagnn.lua;h=65a30e212126532f8796f6c1e35208a018f421fe;hp=484184346a81b6130ab76f5d9a5585396f12bf53;hb=682b76200f755f5f16477e086056a86cafdea1cd;hpb=452781856eafd237579e5c90b6e345354df91b42 diff --git a/dagnn.lua b/dagnn.lua index 4841843..65a30e2 100755 --- a/dagnn.lua +++ b/dagnn.lua @@ -25,32 +25,26 @@ function DAG:addEdge(a, b) succ[a][#succ[a] + 1] = b end -function DAG:setInput(i) - self.sorted = nil - if torch.type(i) == 'table' then - self.inputModules = i - for _, m in ipairs(i) do - if not self.pred[m] and not self.succ[m] then - self:add(m) - end +function DAG:applyOnModules(f, t1, t2) + if torch.type(t1) == 'table' then + local result = {} + for k, s in pairs(t1) do + result[k] = self:applyOnModules(f, s, t2 and t2[k]) end + return result else - self:setInput({ i }) + return f(t1, t2) end end +function DAG:setInput(i) + self.sorted = nil + self.inputModules = i +end + function DAG:setOutput(o) self.sorted = nil - if torch.type(o) == 'table' then - self.outputModules = o - for _, m in ipairs(o) do - if not self.pred[m] and not self.succ[m] then - self:add(m) - end - end - else - self:setOutput({ o }) - end + self.outputModules = o end function DAG:sort() @@ -60,9 +54,7 @@ function DAG:sort() local distance = {} - for _, a in pairs(self.inputModules) do - distance[a] = 1 - end + self:applyOnModules(function(m) distance[m] = 1 end, self.inputModules) local nc @@ -98,13 +90,7 @@ end function DAG:updateOutput(input) self:sort() - if #self.inputModules == 1 then - self.inputModules[1]:updateOutput(input) - else - for i, d in ipairs(self.inputModules) do - d:updateOutput(input[i]) - end - end + self:applyOnModules(function(m, i) m:updateOutput(i) end, self.inputModules, input) for _, d in ipairs(self.sorted) do if self.pred[d] then @@ -120,14 +106,7 @@ function DAG:updateOutput(input) end end - if #self.outputModules == 1 then - self.output = self.outputModules[1].output - else - self.output = { } - for i, d in ipairs(self.outputModules) do - self.output[i] = d.output - end - end + self.output = self:applyOnModules(function(m) return m.output end, self.outputModules) return self.output end