The analytic gradient checks out.
[dagnn.git] / dagnn.lua
1
2 require 'torch'
3 require 'nn'
4
5 local DAG, parent = torch.class('nn.DAG', 'nn.Container')
6
7 function DAG:__init()
8    parent.__init(self)
9    -- Nodes are indexed by the module they contain
10    self.node = { }
11 end
12
13 function DAG:createNode(nnm)
14    if not self.node[nnm] then
15       self:add(nnm) -- Add it to the object as a Container
16       self.node[nnm] = {}
17       self.node[nnm].succ = {}
18       self.node[nnm].pred = {}
19    end
20 end
21
22 function DAG:addEdge(nnma, nnmb)
23    self.sorted = nil
24    self:createNode(nnma)
25    self:createNode(nnmb)
26    table.insert(self.node[nnmb].pred, nnma)
27    table.insert(self.node[nnma].succ, nnmb)
28 end
29
30 -- Apply f on t recursively; use the corresponding element from args
31 -- (i.e. same keys) as second parameter to f when available; return
32 -- the results from f, organized in a similarly nested table.
33 function DAG:nestedApply(f, t, args)
34    if torch.type(t) == 'table' then
35       local result = {}
36       for k, s in pairs(t) do
37          result[k] = self:nestedApply(f, s, args and args[k])
38       end
39       return result
40    else
41       return f(t, args)
42    end
43 end
44
45 function DAG:setInput(i)
46    self.sorted = nil
47    self.inputModules = i
48    self:nestedApply(
49       function(nnm)
50          if #self.node[nnm].succ == 0 then
51             error('Input modules must have outgoing  edges.')
52          end
53          if #self.node[nnm].pred > 0 then
54             error('Input modules cannog have incoming edges.')
55          end
56       end,
57       self.inputModules
58    )
59 end
60
61 function DAG:setOutput(o)
62    self.sorted = nil
63    self.outputModules = o
64    self:nestedApply(
65       function(nnm)
66          if #self.node[nnm].pred == 0 then
67             error('Output module must have incoming edges.')
68          end
69          if #self.node[nnm].succ > 0 then
70             error('Output module cannot have outgoing edges.')
71          end
72       end,
73       self.outputModules
74    )
75 end
76
77 function DAG:putInOrder()
78    if self.sorted then
79       return
80    end
81
82    -- First, we sort the nodes according to the DAG order
83
84    local distance = {}
85
86    self:nestedApply(function(m) distance[m] = 1 end, self.inputModules)
87
88    local nc
89
90    repeat
91       nc = 0
92       for nnma, node in pairs(self.node) do
93          for _, nnmb in pairs(node.succ) do
94             if distance[nnma] and (not distance[nnmb] or distance[nnmb] < distance[nnma] + 1) then
95                distance[nnmb] = distance[nnma] + 1
96                nc = nc + 1
97             end
98          end
99       end
100    until nc == 0
101
102    self.sorted = { }
103    for m, d in pairs(distance) do
104       table.insert(self.sorted, { distance = d, nnm = m })
105    end
106
107    table.sort(self.sorted, function(a, b) return a.distance < b.distance end)
108
109    for i, a in ipairs(self.sorted) do self.sorted[i] = a.nnm end
110 end
111
112 function DAG:print()
113    self:putInOrder()
114
115    for i, d in ipairs(self.sorted) do
116       print('#' .. i .. ' -> ' .. torch.type(d))
117    end
118 end
119
120 function DAG:updateOutput(input)
121    self:putInOrder()
122
123    self:nestedApply(
124       function(nnm, i)
125          self.node[nnm].input = i
126          nnm:updateOutput(i)
127       end,
128       self.inputModules,
129       input
130    )
131
132    for _, nnm in ipairs(self.sorted) do
133       local node = self.node[nnm]
134       if #node.pred > 0 then
135          local i
136          if #node.pred == 1 then
137             i = node.pred[1].output
138          elseif #node.pred > 1 then
139             i = {}
140             for k = 1, #node.pred do
141                i[k] = node.pred[k].output
142             end
143          end
144          node.input = i
145          nnm:updateOutput(i)
146       end
147    end
148
149    self.output = self:nestedApply(
150       function(m) return m.output end,
151       self.outputModules
152    )
153
154    return self.output
155 end
156
157 function DAG:computeGradInput(gradInputSucc)
158    local gi
159    if #gradInputSucc == 1 then
160       gi = gradInputSucc[1] -- we avoid a clone()
161    elseif #gradInputSucc > 1 then
162       for k = 1, #gradInputSucc do
163          if gi then
164             gi:add(gradInputSucc[k])
165          else
166             gi = gradInputSucc[k]:clone()
167          end
168       end
169    end
170    return gi
171 end
172
173 function DAG:updateGradInput(input, gradOutput)
174    self:putInOrder()
175
176    self:nestedApply(
177       function(nnm, go) nnm:updateGradInput(self.node[nnm].input, go) end,
178       self.outputModules, gradOutput
179    )
180
181    self:nestedApply(
182       function(nnm, i) self.node[nnm].input = i end,
183       self.inputModules, input
184    )
185
186    for _, node in pairs(self.node) do
187       node.gradInputSucc = {}
188    end
189
190    for k = #self.sorted, 1, -1 do
191       local nnm = self.sorted[k]
192       local node = self.node[nnm]
193       local pred, gradInputSucc = node.pred, node.gradInputSucc
194
195       if #gradInputSucc > 0 then
196          nnm:updateGradInput(node.input, self:computeGradInput(gradInputSucc))
197       end
198
199       -- We fill the gradInputSucc of our predecessors
200       if #pred == 1 then
201          table.insert(self.node[pred[1]].gradInputSucc, nnm.gradInput)
202       elseif #pred > 1 then
203          if not torch.type(nnm.gradInput) == 'table' then
204             error('Should have a table gradInput since it has multiple predecessors')
205          end
206          for n = 1, #pred do
207             table.insert(self.node[node.pred[n]].gradInputSucc, nnm.gradInput[n])
208          end
209       end
210    end
211
212    self.gradInput = self:nestedApply(function(m) return m.gradInput end, self.inputModules)
213
214    return self.gradInput
215 end
216
217 function DAG:accGradParameters(input, gradOutput, scale)
218    scale = scale or 1
219
220    self:putInOrder()
221
222    self:nestedApply(
223       function(nnm, go) nnm:updateGradInput(self.node[nnm].input, go) end,
224       self.outputModules, gradOutput
225    )
226
227    self:nestedApply(
228       function(nnm, i) self.node[nnm].input = i end,
229       self.inputModules, input
230    )
231
232    for k = #self.sorted, 1, -1 do
233       local nnm = self.sorted[k]
234       local node = self.node[nnm]
235       nnm:accGradParameters(node.input, self:computeGradInput(node.gradInputSucc), scale)
236    end
237 end
238
239 return DAG