X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=dagnn.git;a=blobdiff_plain;f=dagnn.lua;h=484184346a81b6130ab76f5d9a5585396f12bf53;hp=52913ad9992f404277a9b695be590baaa9cdedea;hb=452781856eafd237579e5c90b6e345354df91b42;hpb=be03a73e411d18082a2dd99bff5df45c085017ca diff --git a/dagnn.lua b/dagnn.lua index 52913ad..4841843 100755 --- a/dagnn.lua +++ b/dagnn.lua @@ -11,6 +11,7 @@ function DAG:__init() end function DAG:addEdge(a, b) + self.sorted = nil local pred, succ = self.pred, self.succ if not pred[a] and not succ[a] then self:add(a) @@ -25,6 +26,7 @@ function DAG:addEdge(a, b) end function DAG:setInput(i) + self.sorted = nil if torch.type(i) == 'table' then self.inputModules = i for _, m in ipairs(i) do @@ -38,6 +40,7 @@ function DAG:setInput(i) end function DAG:setOutput(o) + self.sorted = nil if torch.type(o) == 'table' then self.outputModules = o for _, m in ipairs(o) do @@ -50,7 +53,11 @@ function DAG:setOutput(o) end end -function DAG:order() +function DAG:sort() + if self.sorted then + return + end + local distance = {} for _, a in pairs(self.inputModules) do @@ -81,12 +88,16 @@ function DAG:order() end function DAG:print() + self:sort() + for i, d in ipairs(self.sorted) do print('#' .. i .. ' -> ' .. torch.type(d)) end end function DAG:updateOutput(input) + self:sort() + if #self.inputModules == 1 then self.inputModules[1]:updateOutput(input) else @@ -120,3 +131,9 @@ function DAG:updateOutput(input) return self.output end + +function DAG:updateGradInput(input, gradOutput) + self:sort() +end + +return DAG