Update.
[dagnn.git] / graphnn.lua
index a3ee1c1..1003500 100755 (executable)
@@ -31,7 +31,7 @@ end
 
 function Graph:setInput(i)
    if torch.type(i) == 'table' then
-      self.input = i
+      self.inputModules = i
       for _, m in ipairs(i) do
          if not self.pred[m] and not self.succ[m] then
             self:add(m)
@@ -44,7 +44,7 @@ end
 
 function Graph:setOutput(o)
    if torch.type(o) == 'table' then
-      self.output = o
+      self.outputModules = o
       for _, m in ipairs(o) do
          if not self.pred[m] and not self.succ[m] then
             self:add(m)
@@ -58,7 +58,7 @@ end
 function Graph:order()
    local distance = {}
 
-   for _, a in pairs(self.input) do
+   for _, a in pairs(self.inputModules) do
       distance[a] = 1
    end
 
@@ -92,7 +92,38 @@ function Graph:print()
 end
 
 function Graph:updateOutput(input)
-   return self.output.output
+   if #self.inputModules == 1 then
+      self.inputModules[1]:updateOutput(input)
+   else
+      for i, d in ipairs(self.inputModules) do
+         d:updateOutput(input[i])
+      end
+   end
+
+   for _, d in ipairs(self.sorted) do
+      if self.pred[d] then
+         if #self.pred[d] == 1 then
+            d:updateOutput(self.pred[d][1].output)
+         elseif #self.pred[d] > 1 then
+            local c = {}
+            for k = 1, #self.pred[d] do
+               c[k] = self.pred[d][k].output
+            end
+            d:updateOutput(c)
+         end
+      end
+   end
+
+   if #self.outputModules == 1 then
+      self.output = self.outputModules[1].output
+   else
+      self.output = { }
+      for i, d in ipairs(self.outputModules) do
+         self.output[i] = d.output
+      end
+   end
+
+   return self.output
 end
 
 ----------------------------------------------------------------------
@@ -102,6 +133,7 @@ b = nn.ReLU()
 c = nn.Linear(10, 3)
 d = nn.Linear(10, 3)
 e = nn.CMulTable()
+f = nn.Linear(3, 2)
 
 --[[
 
@@ -114,12 +146,21 @@ e = nn.CMulTable()
 g = Graph:new()
 
 g:setInput(a)
-g:setOutput(e)
+g:setOutput({ e, f })
 g:addEdge(c, e)
 g:addEdge(a, b)
 g:addEdge(d, e)
 g:addEdge(b, c)
 g:addEdge(b, d)
+g:addEdge(d, f)
 
 g:order()
+
 g:print(graph)
+
+input = torch.Tensor(3, 10):uniform()
+
+output = g:updateOutput(input)
+
+print(output[1])
+print(output[2])