05672e9bad5997012063b7c4f76510306cb58181
[dagnn.git] / dagnn.lua
1
2 require 'torch'
3 require 'nn'
4
5 local DAG, parent = torch.class('nn.DAG', 'nn.Container')
6
7 function DAG:__init()
8    parent.__init(self)
9    -- Nodes are indexed by the module they encompass
10    self.node = { }
11 end
12
13 function DAG:createNode(nnm)
14    if not self.node[nnm] then
15       self:add(nnm) -- Add it to the object as a Container
16       self.node[nnm] = {}
17       self.node[nnm].succ = {}
18       self.node[nnm].pred = {}
19    end
20 end
21
22 function DAG:addEdge(nnma, nnmb)
23    self.sorted = nil
24    self:createNode(nnma)
25    self:createNode(nnmb)
26    table.insert(self.node[nnmb].pred, nnma)
27    table.insert(self.node[nnma].succ, nnmb)
28 end
29
30 -- Apply f on t recursively; use the corresponding a1 and a2 elements
31 -- (i.e. same keys) as second and third parameters to f when
32 -- available; return the results from f, organized in a similarly
33 -- nested table.
34 function DAG:nestApply(f, t, a1, a2)
35    if torch.type(t) == 'table' then
36       local result = {}
37       for k, s in pairs(t) do
38          result[k] = self:nestApply(f, s, a1 and a1[k], a2 and a2[k])
39       end
40       return result
41    else
42       return f(t, a1, a2)
43    end
44 end
45
46 function DAG:setInput(i)
47    self.sorted = nil
48    self.inputModules = i
49    self:nestApply(
50       function(nnm)
51          if #self.node[nnm].succ == 0 then
52             error('Input modules must have outgoing  edges.')
53          end
54          if #self.node[nnm].pred > 0 then
55             error('Input modules cannog have incoming edges.')
56          end
57       end,
58       self.inputModules
59    )
60 end
61
62 function DAG:setOutput(o)
63    self.sorted = nil
64    self.outputModules = o
65    self:nestApply(
66       function(nnm)
67          if #self.node[nnm].pred == 0 then
68             error('Output module must have incoming edges.')
69          end
70          if #self.node[nnm].succ > 0 then
71             error('Output module cannot have outgoing edges.')
72          end
73       end,
74       self.outputModules
75    )
76 end
77
78 function DAG:putInOrder()
79    if self.sorted then
80       return
81    end
82
83    -- First, we sort the nodes according to the DAG order
84
85    local distance = {}
86
87    self:nestApply(function(m) distance[m] = 1 end, self.inputModules)
88
89    local nc
90
91    repeat
92       nc = 0
93       for nnma, node in pairs(self.node) do
94          for _, nnmb in pairs(node.succ) do
95             if distance[nnma] and (not distance[nnmb] or distance[nnmb] < distance[nnma] + 1) then
96                distance[nnmb] = distance[nnma] + 1
97                nc = nc + 1
98             end
99          end
100       end
101    until nc == 0
102
103    self.sorted = { }
104    for m, d in pairs(distance) do
105       table.insert(self.sorted, { distance = d, nnm = m })
106    end
107
108    table.sort(self.sorted, function(a, b) return a.distance < b.distance end)
109
110    for i, a in ipairs(self.sorted) do self.sorted[i] = a.nnm end
111 end
112
113 function DAG:print()
114    self:putInOrder()
115
116    for i, d in ipairs(self.sorted) do
117       print('#' .. i .. ' -> ' .. torch.type(d))
118    end
119 end
120
121 function DAG:updateOutput(input)
122    self:putInOrder()
123
124    self:nestApply(
125       function(nnm, i)
126          self.node[nnm].input = i
127          nnm:updateOutput(i)
128       end,
129       self.inputModules,
130       input
131    )
132
133    for _, nnm in ipairs(self.sorted) do
134       local node = self.node[nnm]
135       if #node.pred > 0 then
136          local i
137          if #node.pred == 1 then
138             i = node.pred[1].output
139          elseif #node.pred > 1 then
140             i = {}
141             for k = 1, #node.pred do
142                i[k] = node.pred[k].output
143             end
144          end
145          node.input = i
146          nnm:updateOutput(i)
147       end
148    end
149
150    self.output = self:nestApply(function(m) return m.output end, self.outputModules)
151
152    return self.output
153 end
154
155 function DAG:updateGradInput(input, gradOutput)
156    self:putInOrder()
157
158    self:nestApply(
159       function(nnm, go) nnm:updateGradInput(self.node[nnm].input, go) end,
160       self.outputModules, gradOutput
161    )
162
163    for _, node in pairs(self.node) do
164       node.gradInputSucc = {}
165    end
166
167    for k = #self.sorted, 1, -1 do
168       local nnm = self.sorted[k]
169       local node = self.node[nnm]
170       local pred, succ, gradInputSucc = node.pred, node.succ, node.gradInputSucc
171
172       if #gradInputSucc > 0 then
173          -- We update nnm:gradInput
174          local gi
175          if #gradInputSucc == 1 then
176             gi = gradInputSucc[1] -- we avoid a clone()
177          elseif #gradInputSucc > 1 then
178             for k = 1, #gradInputSucc do
179                if gi then
180                   gi:add(gradInputSucc[k])
181                else
182                   gi = gradInputSucc[k]:clone()
183                end
184             end
185          end
186          nnm:updateGradInput(node.input, gi)
187       end
188
189       -- We fill the gradInputSucc of our predecessors
190       if #pred == 1 then
191          table.insert(self.node[pred[1]].gradInputSucc, nnm.gradInput)
192       elseif #pred > 1 then
193          if not torch.type(nnm.gradInput) == 'table' then
194             error('Should have a table gradInput since it has multiple predecessors')
195          end
196          for n = 1, #pred do
197             table.insert(self.node[node.pred[n]].gradInputSucc, nnm.gradInput[n])
198          end
199       end
200    end
201
202    self.gradInput = self:nestApply(function(m) return m.gradInput end, self.inputModules)
203
204    return self.gradInput
205 end
206
207 return DAG