function Graph:setInput(i)
if torch.type(i) == 'table' then
- self.input = i
+ self.inputModules = i
for _, m in ipairs(i) do
if not self.pred[m] and not self.succ[m] then
self:add(m)
function Graph:setOutput(o)
if torch.type(o) == 'table' then
- self.output = o
+ self.outputModules = o
for _, m in ipairs(o) do
if not self.pred[m] and not self.succ[m] then
self:add(m)
function Graph:order()
local distance = {}
- for _, a in pairs(self.input) do
+ for _, a in pairs(self.inputModules) do
distance[a] = 1
end
end
function Graph:updateOutput(input)
- return self.output.output
+ if #self.inputModules == 1 then
+ self.inputModules[1]:updateOutput(input)
+ else
+ for i, d in ipairs(self.inputModules) do
+ d:updateOutput(input[i])
+ end
+ end
+
+ for _, d in ipairs(self.sorted) do
+ if self.pred[d] then
+ if #self.pred[d] == 1 then
+ d:updateOutput(self.pred[d][1].output)
+ elseif #self.pred[d] > 1 then
+ local c = {}
+ for k = 1, #self.pred[d] do
+ c[k] = self.pred[d][k].output
+ end
+ d:updateOutput(c)
+ end
+ end
+ end
+
+ if #self.outputModules == 1 then
+ self.output = self.outputModules[1].output
+ else
+ self.output = { }
+ for i, d in ipairs(self.outputModules) do
+ self.output[i] = d.output
+ end
+ end
+
+ return self.output
end
----------------------------------------------------------------------
c = nn.Linear(10, 3)
d = nn.Linear(10, 3)
e = nn.CMulTable()
+f = nn.Linear(3, 2)
--[[
g = Graph:new()
g:setInput(a)
-g:setOutput(e)
+g:setOutput({ e, f })
g:addEdge(c, e)
g:addEdge(a, b)
g:addEdge(d, e)
g:addEdge(b, c)
g:addEdge(b, d)
+g:addEdge(d, f)
g:order()
+
g:print(graph)
+
+input = torch.Tensor(3, 10):uniform()
+
+output = g:updateOutput(input)
+
+print(output[1])
+print(output[2])