Update.
[dagnn.git] / dagnn.lua
1
2 --[[
3
4    Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/
5    Written by Francois Fleuret <francois.fleuret@idiap.ch>
6
7    This file is free software: you can redistribute it and/or modify
8    it under the terms of the GNU General Public License version 3 as
9    published by the Free Software Foundation.
10
11    It is distributed in the hope that it will be useful, but WITHOUT
12    ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
13    or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public
14    License for more details.
15
16    You should have received a copy of the GNU General Public License
17    along with this file.  If not, see <http://www.gnu.org/licenses/>.
18
19 ]]--
20
21 require 'torch'
22 require 'nn'
23
24 local DAG, parent = torch.class('nn.DAG', 'nn.Container')
25
26 function DAG:__init()
27    parent.__init(self)
28    -- Nodes are indexed by the module they contain
29    self.node = { }
30 end
31
32 function DAG:createNode(nnm)
33    if not self.node[nnm] then
34       self:add(nnm) -- Add it to the object as a Container
35       local node = {}
36       node.succ = {}
37       node.pred = {}
38       node.index = #self.modules
39       self.node[nnm] = node
40    end
41 end
42
43 function DAG:addEdge(nnma, nnmb)
44    self.sorted = nil
45    self:createNode(nnma)
46    self:createNode(nnmb)
47    table.insert(self.node[nnmb].pred, nnma)
48    table.insert(self.node[nnma].succ, nnmb)
49 end
50
51 -- Apply f on t recursively; use the corresponding element from args
52 -- (i.e. same keys) as second parameter to f when available; return
53 -- the results from f, organized in a similarly nested table.
54 function DAG:nestedApply(f, t, args)
55    if torch.type(t) == 'table' then
56       local result = {}
57       for k, s in pairs(t) do
58          result[k] = self:nestedApply(f, s, args and args[k])
59       end
60       return result
61    else
62       return f(t, args)
63    end
64 end
65
66 function DAG:setInput(i)
67    self.sorted = nil
68    self.inputModules = i
69    self:nestedApply(
70       function(nnm)
71          if #self.node[nnm].succ == 0 then
72             error('Input modules must have outgoing  edges.')
73          end
74          if #self.node[nnm].pred > 0 then
75             error('Input modules cannog have incoming edges.')
76          end
77       end,
78       self.inputModules
79    )
80 end
81
82 function DAG:setOutput(o)
83    self.sorted = nil
84    self.outputModules = o
85    self:nestedApply(
86       function(nnm)
87          if #self.node[nnm].pred == 0 then
88             error('Output module must have incoming edges.')
89          end
90          if #self.node[nnm].succ > 0 then
91             error('Output module cannot have outgoing edges.')
92          end
93       end,
94       self.outputModules
95    )
96 end
97
98 function DAG:putInOrder()
99    if self.sorted then
100       return
101    end
102
103    local distance = {}
104    self:nestedApply(function(m) distance[m] = 1 end, self.inputModules)
105
106    local nc
107    repeat
108       nc = 0
109       for nnma, node in pairs(self.node) do
110          for _, nnmb in pairs(node.succ) do
111             if distance[nnma] and (not distance[nnmb] or distance[nnmb] < distance[nnma] + 1) then
112                distance[nnmb] = distance[nnma] + 1
113                nc = nc + 1
114             end
115          end
116       end
117    until nc == 0
118
119    self.sorted = { }
120    for m, d in pairs(distance) do
121       table.insert(self.sorted, { distance = d, nnm = m })
122    end
123
124    table.sort(self.sorted, function(a, b) return a.distance < b.distance end)
125
126    for i, a in ipairs(self.sorted) do self.sorted[i] = a.nnm end
127 end
128
129 function DAG:computeGradOutput(gradInputSucc)
130    local gi
131    if #gradInputSucc == 1 then
132       gi = gradInputSucc[1] -- we avoid a clone()
133    elseif #gradInputSucc > 1 then
134       for k = 1, #gradInputSucc do
135          if gi then
136             gi:add(gradInputSucc[k])
137          else
138             gi = gradInputSucc[k]:clone()
139          end
140       end
141    end
142    return gi
143 end
144
145 function DAG:print()
146    self:putInOrder()
147
148    for i, d in ipairs(self.sorted) do
149       print('#' .. i .. ' -> ' .. torch.type(d))
150    end
151 end
152
153 ----------------------------------------------------------------------
154
155 function DAG:updateOutput(input)
156    self:putInOrder()
157
158    self:nestedApply(
159       function(nnm, i)
160          self.node[nnm].input = i
161          -- nnm:updateOutput(i)
162          self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i)
163       end,
164       self.inputModules,
165       input
166    )
167
168    for _, nnm in ipairs(self.sorted) do
169       local node = self.node[nnm]
170       if #node.pred > 0 then
171          local i
172          if #node.pred == 1 then
173             i = node.pred[1].output
174          elseif #node.pred > 1 then
175             i = {}
176             for k = 1, #node.pred do
177                i[k] = node.pred[k].output
178             end
179          end
180          node.input = i
181          -- nnm:updateOutput(i)
182          self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i)
183       end
184    end
185
186    self.output = self:nestedApply(
187       function(m) return m.output end,
188       self.outputModules
189    )
190
191    return self.output
192 end
193
194 function DAG:updateGradInput(input, gradOutput)
195    assert(self.sorted, 'there has been a DAG structure change before a DAG:updateGradInput')
196
197    self:nestedApply(
198       function(nnm, go)
199          -- nnm:updateGradInput(self.node[nnm].input, go)
200          self:rethrowErrors(nnm, self.node[nnm].index, 'updateGradInput', self.node[nnm].input, go)
201       end,
202       self.outputModules, gradOutput
203    )
204
205    self:nestedApply(
206       function(nnm, i) self.node[nnm].input = i end,
207       self.inputModules, input
208    )
209
210    for _, node in pairs(self.node) do
211       node.gradInputSucc = {}
212    end
213
214    for k = #self.sorted, 1, -1 do
215       local nnm = self.sorted[k]
216       local node = self.node[nnm]
217       local pred, gradInputSucc = node.pred, node.gradInputSucc
218
219       if #gradInputSucc > 0 then
220          node.gradOutput = self:computeGradOutput(gradInputSucc)
221          -- nnm:updateGradInput(node.input, node.gradOutput)
222          self:rethrowErrors(nnm, self.node[nnm].index, 'updateGradInput', node.input, node.gradOutput)
223       end
224
225       -- We fill the gradInputSucc of our predecessors
226       if #pred == 1 then
227          table.insert(self.node[pred[1]].gradInputSucc, nnm.gradInput)
228       elseif #pred > 1 then
229          if not torch.type(nnm.gradInput) == 'table' then
230             error('Should have a table gradInput since it has multiple predecessors')
231          end
232          for n = 1, #pred do
233             table.insert(self.node[node.pred[n]].gradInputSucc, nnm.gradInput[n])
234          end
235       end
236    end
237
238    self.gradInput = self:nestedApply(function(m) return m.gradInput end, self.inputModules)
239
240    return self.gradInput
241 end
242
243 function DAG:accGradParameters(input, gradOutput, scale)
244    scale = scale or 1
245
246    assert(self.sorted, 'there has been a DAG structure change before a DAG:accGradParameters')
247
248    for k = 1, #self.modules do
249       local nnm = self.modules[k]
250       local node = self.node[nnm]
251       -- nnm:accGradParameters(node.input, node.gradOutput, scale)
252       self:rethrowErrors(nnm, k, 'accGradParameters', node.input, self:computeGradOutput(node.gradInputSucc), scale)
253    end
254 end
255
256 ----------------------------------------------------------------------