projects
/
dagnn.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
91d2281
)
Update.
author
Francois Fleuret
<francois@fleuret.org>
Sat, 14 Jan 2017 16:04:06 +0000
(17:04 +0100)
committer
Francois Fleuret
<francois@fleuret.org>
Sat, 14 Jan 2017 16:04:06 +0000
(17:04 +0100)
dagnn.lua
patch
|
blob
|
history
diff --git
a/dagnn.lua
b/dagnn.lua
index
0073e39
..
5921c05
100755
(executable)
--- a/
dagnn.lua
+++ b/
dagnn.lua
@@
-86,14
+86,16
@@
function DAG:putInOrder()
for i, a in ipairs(self.sorted) do self.sorted[i] = a.nnm end
end
for i, a in ipairs(self.sorted) do self.sorted[i] = a.nnm end
end
--- This accumulate x in a where they are both nested tables of
--- tensors. If first is true, set a = x.
+-- This accumulates x in a where they are both nested tables of
+-- tensors. If first is true, set a = x. Behavior is undefined if a
+-- and x do not have the exact same structure.
function DAG:nestedAccTensor(a, x, first)
if torch.type(x) == 'table' then
function DAG:nestedAccTensor(a, x, first)
if torch.type(x) == 'table' then
-
a = a or
{}
+
local b =
{}
for i in pairs(x) do
for i in pairs(x) do
-
a
[i] = self:nestedAccTensor(a[i], x[i], first)
+
b
[i] = self:nestedAccTensor(a[i], x[i], first)
end
end
+ a = b
else
if first then
if a then
else
if first then
if a then
@@
-222,8
+224,9
@@
function DAG:updateOutput(input)
self:nestedApply(
function(nnm, i)
self:nestedApply(
function(nnm, i)
- self.node[nnm].input = i
- self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i)
+ local node = self.node[nnm]
+ node.input = i
+ self:rethrowErrors(nnm, node.index, 'updateOutput', i)
end,
self.inputModules,
input
end,
self.inputModules,
input
@@
-242,7
+245,7
@@
function DAG:updateOutput(input)
end
end
node.input = i
end
end
node.input = i
- self:rethrowErrors(nnm,
self.node[nnm]
.index, 'updateOutput', i)
+ self:rethrowErrors(nnm,
node
.index, 'updateOutput', i)
end
end
end
end
@@
-261,7
+264,7
@@
function DAG:updateGradInput(input, gradOutput)
function(nnm, go)
local node = self.node[nnm]
node.gradOutput = go
function(nnm, go)
local node = self.node[nnm]
node.gradOutput = go
- self:rethrowErrors(nnm, node.index, 'updateGradInput',
self.node[nnm]
.input, go)
+ self:rethrowErrors(nnm, node.index, 'updateGradInput',
node
.input, go)
end,
self.outputModules, gradOutput
)
end,
self.outputModules, gradOutput
)
@@
-282,7
+285,7
@@
function DAG:updateGradInput(input, gradOutput)
if #node.gradInputSucc > 0 then
self:updateGradOutput(node)
if #node.gradInputSucc > 0 then
self:updateGradOutput(node)
- self:rethrowErrors(nnm,
self.node[nnm]
.index, 'updateGradInput', node.input, node.gradOutput)
+ self:rethrowErrors(nnm,
node
.index, 'updateGradInput', node.input, node.gradOutput)
end
-- We fill the gradInputSucc of our predecessors
end
-- We fill the gradInputSucc of our predecessors
@@
-304,8
+307,6
@@
function DAG:updateGradInput(input, gradOutput)
end
function DAG:accGradParameters(input, gradOutput, scale)
end
function DAG:accGradParameters(input, gradOutput, scale)
- scale = scale or 1
-
assert(self.sorted, 'There has been a DAG structure change before a DAG:accGradParameters')
self:nestedApply(
assert(self.sorted, 'There has been a DAG structure change before a DAG:accGradParameters')
self:nestedApply(