projects
/
dagnn.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
OCD cosmetics.
[dagnn.git]
/
dagnn.lua
diff --git
a/dagnn.lua
b/dagnn.lua
index
ca26926
..
b82398c
100755
(executable)
--- a/
dagnn.lua
+++ b/
dagnn.lua
@@
-69,6
+69,7
@@
function DAG:putInOrder()
local nc
local nl = 0
repeat
local nc
local nl = 0
repeat
+ assert(nl < #self.modules, 'Cycle detected in the graph.')
nc = 0
for nnma, node in pairs(self.node) do
for _, nnmb in pairs(node.succ) do
nc = 0
for nnma, node in pairs(self.node) do
for _, nnmb in pairs(node.succ) do
@@
-78,12
+79,11
@@
function DAG:putInOrder()
end
end
end
end
end
end
- assert(nl < #self.modules, 'Cycle detected in the graph.')
nl = nl + 1
until nc == 0
for _, nnm in pairs(self.modules) do
nl = nl + 1
until nc == 0
for _, nnm in pairs(self.modules) do
- assert(distance[nnm], 'Some modules are not connected to inputs')
+ assert(distance[nnm], 'Some modules are not connected to inputs
.
')
end
self.sorted = {}
end
self.sorted = {}
@@
-148,6
+148,10
@@
function DAG:connect(...)
end
end
end
end
+function DAG:setLabel(nnm, label)
+ self.node[nnm].label = label
+end
+
function DAG:setInput(i)
self.sorted = nil
self.inputModules = i
function DAG:setInput(i)
self.sorted = nil
self.inputModules = i
@@
-176,7
+180,11
@@
function DAG:print()
self:putInOrder()
for i, d in ipairs(self.sorted) do
self:putInOrder()
for i, d in ipairs(self.sorted) do
- print('#' .. i .. ' -> ' .. torch.type(d))
+ local decoration = ''
+ if self.node[d].label then
+ decoration = ' [' .. self.node[d].label .. ']'
+ end
+ print('#' .. i .. ' -> ' .. torch.type(d) .. decoration)
end
end
end
end
@@
-211,7
+219,7
@@
function DAG:saveDot(filename)
file:write(
' '
.. node.index
file:write(
' '
.. node.index
- .. ' [shape=box,label=\"' ..
torch.type(nnmb
) .. '\"]'
+ .. ' [shape=box,label=\"' ..
(self.node[nnmb].label or torch.type(nnmb)
) .. '\"]'
.. '\n'
)
.. '\n'
)
@@
-280,7
+288,7
@@
function DAG:updateOutput(input)
end
function DAG:updateGradInput(input, gradOutput)
end
function DAG:updateGradInput(input, gradOutput)
- assert(self.sorted, 'There has been a structure change before a DAG:updateGradInput')
+ assert(self.sorted, 'There has been a structure change before a DAG:updateGradInput
.
')
self:nestedApply(
function(nnm, go)
self:nestedApply(
function(nnm, go)
@@
-315,7
+323,7
@@
function DAG:updateGradInput(input, gradOutput)
table.insert(self.node[pred[1]].gradInputSucc, nnm.gradInput)
elseif #pred > 1 then
assert(torch.type(nnm.gradInput) == 'table',
table.insert(self.node[pred[1]].gradInputSucc, nnm.gradInput)
elseif #pred > 1 then
assert(torch.type(nnm.gradInput) == 'table',
- 'Should have a table gradInput since it has multiple predecessors')
+ 'Should have a table gradInput since it has multiple predecessors
.
')
for n = 1, #pred do
table.insert(self.node[pred[n]].gradInputSucc, nnm.gradInput[n])
end
for n = 1, #pred do
table.insert(self.node[pred[n]].gradInputSucc, nnm.gradInput[n])
end
@@
-331,7
+339,7
@@
function DAG:updateGradInput(input, gradOutput)
end
function DAG:accGradParameters(input, gradOutput, scale)
end
function DAG:accGradParameters(input, gradOutput, scale)
- assert(self.sorted, 'There has been a structure change before a DAG:accGradParameters')
+ assert(self.sorted, 'There has been a structure change before a DAG:accGradParameters
.
')
self:nestedApply(
function(nnm, go) self.node[nnm].gradOutput = go end,
self:nestedApply(
function(nnm, go) self.node[nnm].gradOutput = go end,