Wow, seems to work (!)
authorFrancois Fleuret <francois@fleuret.org>
Thu, 12 Jan 2017 14:26:41 +0000 (15:26 +0100)
committerFrancois Fleuret <francois@fleuret.org>
Thu, 12 Jan 2017 14:26:41 +0000 (15:26 +0100)
dagnn.lua
test-dagnn.lua

index 0b8f7d4..05672e9 100755 (executable)
--- a/dagnn.lua
+++ b/dagnn.lua
@@ -10,21 +10,21 @@ function DAG:__init()
    self.node = { }
 end
 
-function DAG:createNode(n)
-   if not self.node[n] then
-      self:add(n) -- Add it to the object as a Container
-      self.node[n] = {}
-      self.node[n].succ = {}
-      self.node[n].pred = {}
+function DAG:createNode(nnm)
+   if not self.node[nnm] then
+      self:add(nnm) -- Add it to the object as a Container
+      self.node[nnm] = {}
+      self.node[nnm].succ = {}
+      self.node[nnm].pred = {}
    end
 end
 
-function DAG:addEdge(a, b)
+function DAG:addEdge(nnma, nnmb)
    self.sorted = nil
-   self:createNode(a)
-   self:createNode(b)
-   table.insert(self.node[b].pred, a)
-   table.insert(self.node[a].succ, b)
+   self:createNode(nnma)
+   self:createNode(nnmb)
+   table.insert(self.node[nnmb].pred, nnma)
+   table.insert(self.node[nnma].succ, nnmb)
 end
 
 -- Apply f on t recursively; use the corresponding a1 and a2 elements
@@ -47,11 +47,11 @@ function DAG:setInput(i)
    self.sorted = nil
    self.inputModules = i
    self:nestApply(
-      function(m)
-         if #self.node[m].succ == 0 then
+      function(nnm)
+         if #self.node[nnm].succ == 0 then
             error('Input modules must have outgoing  edges.')
          end
-         if #self.node[m].pred > 0 then
+         if #self.node[nnm].pred > 0 then
             error('Input modules cannog have incoming edges.')
          end
       end,
@@ -63,11 +63,11 @@ function DAG:setOutput(o)
    self.sorted = nil
    self.outputModules = o
    self:nestApply(
-      function(m)
-         if #self.node[m].pred == 0 then
+      function(nnm)
+         if #self.node[nnm].pred == 0 then
             error('Output module must have incoming edges.')
          end
-         if #self.node[m].succ > 0 then
+         if #self.node[nnm].succ > 0 then
             error('Output module cannot have outgoing edges.')
          end
       end,
@@ -90,10 +90,10 @@ function DAG:putInOrder()
 
    repeat
       nc = 0
-      for i, node in pairs(self.node) do
-         for _, j in pairs(node.succ) do
-            if distance[i] and (not distance[j] or distance[j] < distance[i] + 1) then
-               distance[j] = distance[i] + 1
+      for nnma, node in pairs(self.node) do
+         for _, nnmb in pairs(node.succ) do
+            if distance[nnma] and (not distance[nnmb] or distance[nnmb] < distance[nnma] + 1) then
+               distance[nnmb] = distance[nnma] + 1
                nc = nc + 1
             end
          end
@@ -101,13 +101,13 @@ function DAG:putInOrder()
    until nc == 0
 
    self.sorted = { }
-   for n, d in pairs(distance) do
-      table.insert(self.sorted, { distance = d, node = n })
+   for m, d in pairs(distance) do
+      table.insert(self.sorted, { distance = d, nnm = m })
    end
 
    table.sort(self.sorted, function(a, b) return a.distance < b.distance end)
 
-   for i, a in ipairs(self.sorted) do self.sorted[i] = a.node end
+   for i, a in ipairs(self.sorted) do self.sorted[i] = a.nnm end
 end
 
 function DAG:print()
@@ -121,21 +121,29 @@ end
 function DAG:updateOutput(input)
    self:putInOrder()
 
-   self:nestApply(function(m, i) m:updateOutput(i) end, self.inputModules, input)
+   self:nestApply(
+      function(nnm, i)
+         self.node[nnm].input = i
+         nnm:updateOutput(i)
+      end,
+      self.inputModules,
+      input
+   )
 
-   for _, m in ipairs(self.sorted) do
-      if #self.node[m].pred > 0 then
+   for _, nnm in ipairs(self.sorted) do
+      local node = self.node[nnm]
+      if #node.pred > 0 then
          local i
-         if #self.node[m].pred == 1 then
-            i = self.node[m].pred[1].output
-         elseif #self.node[m].pred > 1 then
+         if #node.pred == 1 then
+            i = node.pred[1].output
+         elseif #node.pred > 1 then
             i = {}
-            for k = 1, #self.node[m].pred do
-               i[k] = self.node[m].pred[k].output
+            for k = 1, #node.pred do
+               i[k] = node.pred[k].output
             end
          end
-         self.node[m].input = i
-         m:updateOutput(i)
+         node.input = i
+         nnm:updateOutput(i)
       end
    end
 
@@ -148,7 +156,7 @@ function DAG:updateGradInput(input, gradOutput)
    self:putInOrder()
 
    self:nestApply(
-      function(m, go) m:updateGradInput(self.node[m].input, go) end,
+      function(nnm, go) nnm:updateGradInput(self.node[nnm].input, go) end,
       self.outputModules, gradOutput
    )
 
@@ -157,31 +165,36 @@ function DAG:updateGradInput(input, gradOutput)
    end
 
    for k = #self.sorted, 1, -1 do
-      local m = self.sorted[k]
-      local node = self.node[m]
+      local nnm = self.sorted[k]
+      local node = self.node[nnm]
       local pred, succ, gradInputSucc = node.pred, node.succ, node.gradInputSucc
 
-      -- We update m:gradInput
-      if #gradInputSucc == 1 then
-         m:updateGradInput(node.input, gradInputSucc[1])
-      elseif #gradInputSucc > 1 then
-         local sum
-         for k = 1, #succ do
-            if sum then
-               sum:add(succ[k].gradInput)
-            else
-               sum = succ[k].gradInput
+      if #gradInputSucc > 0 then
+         -- We update nnm:gradInput
+         local gi
+         if #gradInputSucc == 1 then
+            gi = gradInputSucc[1] -- we avoid a clone()
+         elseif #gradInputSucc > 1 then
+            for k = 1, #gradInputSucc do
+               if gi then
+                  gi:add(gradInputSucc[k])
+               else
+                  gi = gradInputSucc[k]:clone()
+               end
             end
          end
-         m:updateGradInput(node.input, sum)
+         nnm:updateGradInput(node.input, gi)
       end
 
       -- We fill the gradInputSucc of our predecessors
       if #pred == 1 then
-         table.insert(self.node[pred[1]].gradInputSucc, node.gradInput)
+         table.insert(self.node[pred[1]].gradInputSucc, nnm.gradInput)
       elseif #pred > 1 then
+         if not torch.type(nnm.gradInput) == 'table' then
+            error('Should have a table gradInput since it has multiple predecessors')
+         end
          for n = 1, #pred do
-            table.insert(self.node[node.pred[n]].gradInputSucc, m.gradInput[n])
+            table.insert(self.node[node.pred[n]].gradInputSucc, nnm.gradInput[n])
          end
       end
    end
index 0c9fe6d..32eed57 100755 (executable)
@@ -42,14 +42,14 @@ g:addEdge(b, c)
 g:addEdge(b, d)
 g:addEdge(d, f)
 
-g:setInput({ a })
+g:setInput({{a}})
 g:setOutput({ e, f })
 
 g:print()
 
 input = torch.Tensor(3, 10):uniform()
 
-output = g:updateOutput({ input })
+output = g:updateOutput({{ input }})
 
 printTensorTable(output)