Update.
[dagnn.git] / dagnn.lua
1
2 require 'torch'
3 require 'nn'
4
5 local DAG, parent = torch.class('nn.DAG', 'nn.Container')
6
7 function DAG:__init()
8    parent.__init(self)
9    self.pred = {}
10    self.succ = {}
11 end
12
13 function DAG:addEdge(a, b)
14    local pred, succ = self.pred, self.succ
15    if not pred[a] and not succ[a] then
16       self:add(a)
17    end
18    if not pred[b] and not succ[b] then
19       self:add(b)
20    end
21    pred[b] = pred[b] or {}
22    pred[b][#pred[b] + 1] = a
23    succ[a] = succ[a] or {}
24    succ[a][#succ[a] + 1] = b
25 end
26
27 function DAG:setInput(i)
28    if torch.type(i) == 'table' then
29       self.inputModules = i
30       for _, m in ipairs(i) do
31          if not self.pred[m] and not self.succ[m] then
32             self:add(m)
33          end
34       end
35    else
36       self:setInput({ i })
37    end
38 end
39
40 function DAG:setOutput(o)
41    if torch.type(o) == 'table' then
42       self.outputModules = o
43       for _, m in ipairs(o) do
44          if not self.pred[m] and not self.succ[m] then
45             self:add(m)
46          end
47       end
48    else
49       self:setOutput({ o })
50    end
51 end
52
53 function DAG:order()
54    local distance = {}
55
56    for _, a in pairs(self.inputModules) do
57       distance[a] = 1
58    end
59
60    local nc
61
62    repeat
63       nc = 0
64       for i, isucc in pairs(self.succ) do
65          for _, j in pairs(isucc) do
66             if distance[i] and (not distance[j] or distance[j] < distance[i] + 1) then
67                distance[j] = distance[i] + 1
68                nc = nc + 1
69             end
70          end
71       end
72    until nc == 0
73
74    self.sorted = { }
75    for i, d in pairs(distance) do
76       table.insert(self.sorted, { d, i })
77    end
78
79    table.sort(self.sorted, function(a, b) return a[1] < b[1] end)
80    for i, a in ipairs(self.sorted) do self.sorted[i] = a[2] end
81 end
82
83 function DAG:print()
84    for i, d in ipairs(self.sorted) do
85       print('#' .. i .. ' -> ' .. torch.type(d))
86    end
87 end
88
89 function DAG:updateOutput(input)
90    if #self.inputModules == 1 then
91       self.inputModules[1]:updateOutput(input)
92    else
93       for i, d in ipairs(self.inputModules) do
94          d:updateOutput(input[i])
95       end
96    end
97
98    for _, d in ipairs(self.sorted) do
99       if self.pred[d] then
100          if #self.pred[d] == 1 then
101             d:updateOutput(self.pred[d][1].output)
102          elseif #self.pred[d] > 1 then
103             local c = {}
104             for k = 1, #self.pred[d] do
105                c[k] = self.pred[d][k].output
106             end
107             d:updateOutput(c)
108          end
109       end
110    end
111
112    if #self.outputModules == 1 then
113       self.output = self.outputModules[1].output
114    else
115       self.output = { }
116       for i, d in ipairs(self.outputModules) do
117          self.output[i] = d.output
118       end
119    end
120
121    return self.output
122 end