Update.
authorFrancois Fleuret <francois@fleuret.org>
Wed, 11 Jan 2017 07:12:22 +0000 (08:12 +0100)
committerFrancois Fleuret <francois@fleuret.org>
Wed, 11 Jan 2017 07:12:22 +0000 (08:12 +0100)
dagnn.lua
test-dagnn.lua

index 52913ad..4841843 100755 (executable)
--- a/dagnn.lua
+++ b/dagnn.lua
@@ -11,6 +11,7 @@ function DAG:__init()
 end
 
 function DAG:addEdge(a, b)
+   self.sorted = nil
    local pred, succ = self.pred, self.succ
    if not pred[a] and not succ[a] then
       self:add(a)
@@ -25,6 +26,7 @@ function DAG:addEdge(a, b)
 end
 
 function DAG:setInput(i)
+   self.sorted = nil
    if torch.type(i) == 'table' then
       self.inputModules = i
       for _, m in ipairs(i) do
@@ -38,6 +40,7 @@ function DAG:setInput(i)
 end
 
 function DAG:setOutput(o)
+   self.sorted = nil
    if torch.type(o) == 'table' then
       self.outputModules = o
       for _, m in ipairs(o) do
@@ -50,7 +53,11 @@ function DAG:setOutput(o)
    end
 end
 
-function DAG:order()
+function DAG:sort()
+   if self.sorted then
+      return
+   end
+
    local distance = {}
 
    for _, a in pairs(self.inputModules) do
@@ -81,12 +88,16 @@ function DAG:order()
 end
 
 function DAG:print()
+   self:sort()
+
    for i, d in ipairs(self.sorted) do
       print('#' .. i .. ' -> ' .. torch.type(d))
    end
 end
 
 function DAG:updateOutput(input)
+   self:sort()
+
    if #self.inputModules == 1 then
       self.inputModules[1]:updateOutput(input)
    else
@@ -120,3 +131,9 @@ function DAG:updateOutput(input)
 
    return self.output
 end
+
+function DAG:updateGradInput(input, gradOutput)
+   self:sort()
+end
+
+return DAG
index a0a81ab..262ea6f 100755 (executable)
@@ -21,10 +21,11 @@ f = nn.Linear(3, 2)
                      \---> f ---
 ]]--
 
-g = DAG:new()
+g = nn.DAG:new()
 
 g:setInput(a)
 g:setOutput({ e, f })
+
 g:addEdge(c, e)
 g:addEdge(a, b)
 g:addEdge(d, e)
@@ -32,9 +33,7 @@ g:addEdge(b, c)
 g:addEdge(b, d)
 g:addEdge(d, f)
 
-g:order()
-
-g:print(graph)
+g:print()
 
 input = torch.Tensor(3, 10):uniform()