Update.
[dagnn.git] / dagnn.lua
1
2 require 'torch'
3 require 'nn'
4
5 local DAG, parent = torch.class('nn.DAG', 'nn.Container')
6
7 function DAG:__init()
8    parent.__init(self)
9    -- Nodes are indexed by the module they encompass
10    self.node = { }
11 end
12
13 function DAG:createNode(nnm)
14    if not self.node[nnm] then
15       self:add(nnm) -- Add it to the object as a Container
16       self.node[nnm] = {}
17       self.node[nnm].succ = {}
18       self.node[nnm].pred = {}
19    end
20 end
21
22 function DAG:addEdge(nnma, nnmb)
23    self.sorted = nil
24    self:createNode(nnma)
25    self:createNode(nnmb)
26    table.insert(self.node[nnmb].pred, nnma)
27    table.insert(self.node[nnma].succ, nnmb)
28 end
29
30 -- Apply f on t recursively; use the corresponding a1 and a2 elements
31 -- (i.e. same keys) as second and third parameters to f when
32 -- available; return the results from f, organized in a similarly
33 -- nested table.
34 function DAG:nestApply(f, t, a1, a2)
35    if torch.type(t) == 'table' then
36       local result = {}
37       for k, s in pairs(t) do
38          result[k] = self:nestApply(f, s, a1 and a1[k], a2 and a2[k])
39       end
40       return result
41    else
42       return f(t, a1, a2)
43    end
44 end
45
46 function DAG:setInput(i)
47    self.sorted = nil
48    self.inputModules = i
49    self:nestApply(
50       function(nnm)
51          if #self.node[nnm].succ == 0 then
52             error('Input modules must have outgoing  edges.')
53          end
54          if #self.node[nnm].pred > 0 then
55             error('Input modules cannog have incoming edges.')
56          end
57       end,
58       self.inputModules
59    )
60 end
61
62 function DAG:setOutput(o)
63    self.sorted = nil
64    self.outputModules = o
65    self:nestApply(
66       function(nnm)
67          if #self.node[nnm].pred == 0 then
68             error('Output module must have incoming edges.')
69          end
70          if #self.node[nnm].succ > 0 then
71             error('Output module cannot have outgoing edges.')
72          end
73       end,
74       self.outputModules
75    )
76 end
77
78 function DAG:putInOrder()
79    if self.sorted then
80       return
81    end
82
83    -- First, we sort the nodes according to the DAG order
84
85    local distance = {}
86
87    self:nestApply(function(m) distance[m] = 1 end, self.inputModules)
88
89    local nc
90
91    repeat
92       nc = 0
93       for nnma, node in pairs(self.node) do
94          for _, nnmb in pairs(node.succ) do
95             if distance[nnma] and (not distance[nnmb] or distance[nnmb] < distance[nnma] + 1) then
96                distance[nnmb] = distance[nnma] + 1
97                nc = nc + 1
98             end
99          end
100       end
101    until nc == 0
102
103    self.sorted = { }
104    for m, d in pairs(distance) do
105       table.insert(self.sorted, { distance = d, nnm = m })
106    end
107
108    table.sort(self.sorted, function(a, b) return a.distance < b.distance end)
109
110    for i, a in ipairs(self.sorted) do self.sorted[i] = a.nnm end
111 end
112
113 function DAG:print()
114    self:putInOrder()
115
116    for i, d in ipairs(self.sorted) do
117       print('#' .. i .. ' -> ' .. torch.type(d))
118    end
119 end
120
121 function DAG:updateOutput(input)
122    self:putInOrder()
123
124    self:nestApply(
125       function(nnm, i)
126          self.node[nnm].input = i
127          nnm:updateOutput(i)
128       end,
129       self.inputModules,
130       input
131    )
132
133    for _, nnm in ipairs(self.sorted) do
134       local node = self.node[nnm]
135       if #node.pred > 0 then
136          local i
137          if #node.pred == 1 then
138             i = node.pred[1].output
139          elseif #node.pred > 1 then
140             i = {}
141             for k = 1, #node.pred do
142                i[k] = node.pred[k].output
143             end
144          end
145          node.input = i
146          nnm:updateOutput(i)
147       end
148    end
149
150    self.output = self:nestApply(function(m) return m.output end, self.outputModules)
151
152    return self.output
153 end
154
155 function DAG:computeGradInput(gradInputSucc)
156    local gi
157    if #gradInputSucc == 1 then
158       gi = gradInputSucc[1] -- we avoid a clone()
159    elseif #gradInputSucc > 1 then
160       for k = 1, #gradInputSucc do
161          if gi then
162             gi:add(gradInputSucc[k])
163          else
164             gi = gradInputSucc[k]:clone()
165          end
166       end
167    end
168    return gi
169 end
170
171 function DAG:updateGradInput(input, gradOutput)
172    self:putInOrder()
173
174    self:nestApply(
175       function(nnm, go) nnm:updateGradInput(self.node[nnm].input, go) end,
176       self.outputModules, gradOutput
177    )
178
179    self:nestApply(
180       function(nnm, i) self.node[nnm].input = i end,
181       self.inputModules, input
182    )
183
184    for _, node in pairs(self.node) do
185       node.gradInputSucc = {}
186    end
187
188    for k = #self.sorted, 1, -1 do
189       local nnm = self.sorted[k]
190       local node = self.node[nnm]
191       local pred, gradInputSucc = node.pred, node.gradInputSucc
192
193       if #gradInputSucc > 0 then
194          nnm:updateGradInput(node.input, self:computeGradInput(gradInputSucc))
195       end
196
197       -- We fill the gradInputSucc of our predecessors
198       if #pred == 1 then
199          table.insert(self.node[pred[1]].gradInputSucc, nnm.gradInput)
200       elseif #pred > 1 then
201          if not torch.type(nnm.gradInput) == 'table' then
202             error('Should have a table gradInput since it has multiple predecessors')
203          end
204          for n = 1, #pred do
205             table.insert(self.node[node.pred[n]].gradInputSucc, nnm.gradInput[n])
206          end
207       end
208    end
209
210    self.gradInput = self:nestApply(function(m) return m.gradInput end, self.inputModules)
211
212    return self.gradInput
213 end
214
215 return DAG