#!/usr/bin/env luajit

require 'torch'
require 'nn'
require 'image'
require 'optim'

----------------------------------------------------------------------
-- nn.Graph: a container whose child modules are wired together as a
-- directed graph.  addEdge(a, b) declares that b consumes a's output,
-- setInput / setOutput declare the entry and exit modules, and order()
-- computes the evaluation order used by updateOutput.  Only the forward
-- pass (updateOutput) is defined here.
----------------------------------------------------------------------

local Graph, parent = torch.class('nn.Graph', 'nn.Container')

--- Create an empty graph.
function Graph:__init()
   parent.__init(self)
   self.pred = {}       -- pred[m] = modules feeding m, in edge-insertion order
   self.succ = {}       -- succ[m] = modules consuming m's output
   self.nodeIndex = {}  -- nodeIndex[m] = position of m in self.modules
end

-- Add m to the container exactly once and remember its insertion index.
-- Modules first seen through setInput/setOutput never appear in pred/succ,
-- so those tables cannot be used to tell whether m was already added.
function Graph:_addNode(m)
   self.nodeIndex = self.nodeIndex or {}
   if not self.nodeIndex[m] then
      self:add(m)
      self.nodeIndex[m] = #self.modules
   end
end

--- Declare a directed edge: b receives a's output as (part of) its input.
-- A module with several predecessors receives a table of their outputs,
-- ordered by the sequence in which the edges were added.
-- @param a producer module
-- @param b consumer module
function Graph:addEdge(a, b)
   self:_addNode(a)
   self:_addNode(b)
   local pred, succ = self.pred, self.succ
   pred[b] = pred[b] or {}
   pred[b][#pred[b] + 1] = a
   succ[a] = succ[a] or {}
   succ[a][#succ[a] + 1] = b
   self.sorted = nil -- topology changed; the evaluation order is stale
end

--- Set the graph's entry module(s).
-- @param i a module, or a sequence of modules; with several inputs,
--          updateOutput feeds input[k] to the k-th module
function Graph:setInput(i)
   if torch.type(i) == 'table' then
      self.inputModules = i
      for _, m in ipairs(i) do
         self:_addNode(m)
      end
      self.sorted = nil
   else
      self:setInput({ i })
   end
end

--- Set the graph's exit module(s).
-- @param o a module, or a sequence of modules; with several outputs,
--          updateOutput returns a table of their outputs in this order
function Graph:setOutput(o)
   if torch.type(o) == 'table' then
      self.outputModules = o
      for _, m in ipairs(o) do
         self:_addNode(m)
      end
      self.sorted = nil
   else
      self:setOutput({ o })
   end
end

--- Compute the evaluation order into self.sorted.
-- Every module reachable from the inputs is labelled with the number of
-- nodes on the longest path reaching it (inputs are 1).  All predecessors
-- of a module therefore carry strictly smaller labels, so sorting by label
-- gives a valid evaluation order.  Ties are broken by registration order,
-- which makes the result deterministic.  Modules unreachable from the
-- inputs are left out.  Raises an error if a reachable cycle exists.
function Graph:order()
   assert(self.inputModules, 'nn.Graph: call setInput() before order()')
   local distance = {}
   for _, m in ipairs(self.inputModules) do
      distance[m] = 1
   end

   -- Every reachable node is registered in self.modules, so in an acyclic
   -- graph no path can contain more nodes than that; exceeding the bound
   -- means the relaxation is chasing a cycle and would never settle.
   local bound = #self.modules
   local changed
   repeat
      changed = false
      for i, isucc in pairs(self.succ) do
         for _, j in ipairs(isucc) do
            local di = distance[i]
            if di and (not distance[j] or distance[j] < di + 1) then
               if di + 1 > bound then
                  error('nn.Graph: cycle detected in graph')
               end
               distance[j] = di + 1
               changed = true
            end
         end
      end
   until not changed

   local index = self.nodeIndex or {}
   local entries = {}
   for m, dist in pairs(distance) do
      entries[#entries + 1] = { dist, index[m] or 0, m }
   end
   table.sort(entries, function(x, y)
      if x[1] ~= y[1] then
         return x[1] < y[1]
      end
      return x[2] < y[2]
   end)

   self.sorted = {}
   for k, entry in ipairs(entries) do
      self.sorted[k] = entry[3]
   end
end

--- Print the evaluation order, one line per module: "#k -> <type>".
function Graph:print()
   if not self.sorted then
      self:order()
   end
   for k, m in ipairs(self.sorted) do
      print('#' .. k .. ' -> ' .. torch.type(m))
   end
end

--- Forward pass through the graph.
-- @param input for a single-input graph, passed to the input module as-is;
--        otherwise a table whose k-th entry feeds inputModules[k]
-- @return the single output module's output, or a table of outputs in
--         setOutput order
function Graph:updateOutput(input)
   assert(self.outputModules, 'nn.Graph: call setOutput() before updateOutput()')
   if not self.sorted then
      self:order()
   end

   local inputs = self.inputModules
   if #inputs == 1 then
      inputs[1]:updateOutput(input)
   else
      for k, m in ipairs(inputs) do
         m:updateOutput(input[k])
      end
   end

   -- Predecessors precede their consumers in self.sorted, so their outputs
   -- are already up to date here.  Modules without predecessors (the
   -- inputs) were evaluated above.
   for _, m in ipairs(self.sorted) do
      local p = self.pred[m]
      if p then
         if #p == 1 then
            m:updateOutput(p[1].output)
         else
            local c = {}
            for k = 1, #p do
               c[k] = p[k].output
            end
            m:updateOutput(c)
         end
      end
   end

   local outputs = self.outputModules
   if #outputs == 1 then
      self.output = outputs[1].output
   else
      self.output = {}
      for k, m in ipairs(outputs) do
         self.output[k] = m.output
      end
   end
   return self.output
end

----------------------------------------------------------------------
-- Demo
--
--   a ---> b ---> c ---> e   (output 1)
--           \            ^
--            \--> d -----/
--                  \
--                   \--> f   (output 2)
----------------------------------------------------------------------

local a = nn.Linear(10, 10)
local b = nn.ReLU()
local c = nn.Linear(10, 3)
local d = nn.Linear(10, 3)
local e = nn.CMulTable()
local f = nn.Linear(3, 2)

local g = nn.Graph()
g:setInput(a)
g:setOutput({ e, f })
g:addEdge(c, e)
g:addEdge(a, b)
g:addEdge(d, e)
g:addEdge(b, c)
g:addEdge(b, d)
g:addEdge(d, f)
g:order()
g:print()

local input = torch.Tensor(3, 10):uniform()
local output = g:updateOutput(input)
print(output[1])
print(output[2])