Update.
[dagnn.git] / dagnn.lua
index 52913ad..4841843 100755 (executable)
--- a/dagnn.lua
+++ b/dagnn.lua
@@ -11,6 +11,7 @@ function DAG:__init()
 end
 
 function DAG:addEdge(a, b)
+   self.sorted = nil
    local pred, succ = self.pred, self.succ
    if not pred[a] and not succ[a] then
       self:add(a)
@@ -25,6 +26,7 @@ function DAG:addEdge(a, b)
 end
 
 function DAG:setInput(i)
+   self.sorted = nil
    if torch.type(i) == 'table' then
       self.inputModules = i
       for _, m in ipairs(i) do
@@ -38,6 +40,7 @@ function DAG:setInput(i)
 end
 
 function DAG:setOutput(o)
+   self.sorted = nil
    if torch.type(o) == 'table' then
       self.outputModules = o
       for _, m in ipairs(o) do
@@ -50,7 +53,11 @@ function DAG:setOutput(o)
    end
 end
 
-function DAG:order()
+function DAG:sort()
+   if self.sorted then
+      return
+   end
+
    local distance = {}
 
    for _, a in pairs(self.inputModules) do
@@ -81,12 +88,16 @@ function DAG:order()
 end
 
 function DAG:print()
+   self:sort()
+
    for i, d in ipairs(self.sorted) do
       print('#' .. i .. ' -> ' .. torch.type(d))
    end
 end
 
 function DAG:updateOutput(input)
+   self:sort()
+
    if #self.inputModules == 1 then
       self.inputModules[1]:updateOutput(input)
    else
@@ -120,3 +131,9 @@ function DAG:updateOutput(input)
 
    return self.output
 end
+
+function DAG:updateGradInput(input, gradOutput)
+   self:sort()
+end
+
+return DAG