Fixed the initialization in gradOutput in accGradParameters + cosmetics.
[dagnn.git] / dagnn.lua
index 14cd582..0c1d153 100755 (executable)
--- a/dagnn.lua
+++ b/dagnn.lua
@@ -29,7 +29,7 @@ function DAG:__init()
    self.node = { }
 end
 
--- Apply f on t recursively; use the corresponding element from args
+-- Apply f on t recursively; use the corresponding elements from args
 -- (i.e. same keys) as second parameter to f when available; return
 -- the results from f, organized in a similarly nested table.
 function DAG:nestedApply(f, t, args)
@@ -167,20 +167,26 @@ function DAG:saveDot(filename)
 
    file:write('\n')
 
-   for nnma, node in pairs(self.node) do
+   for nnmb, node in pairs(self.node) do
       file:write(
          '  '
             .. node.index
-            .. ' [shape=box,label=\"' .. torch.type(nnma) .. '\"]'
+            .. ' [shape=box,label=\"' .. torch.type(nnmb) .. '\"]'
             .. '\n'
       )
 
-      for _, nnmb in pairs(node.succ) do
+      for i, nnma in pairs(node.pred) do
+         local decoration = ''
+         if #node.pred > 1 then
+            -- decoration = ' [headlabel=\"' .. i .. '\"]'
+            decoration = ' [label=\"' .. i .. '\"]'
+         end
          file:write(
             '  '
-               .. node.index
+               .. self.node[nnma].index
                .. ' -> '
                .. self.node[nnmb].index
+               .. decoration
                .. '\n'
          )
       end
@@ -234,12 +240,14 @@ function DAG:updateOutput(input)
 end
 
 function DAG:updateGradInput(input, gradOutput)
-   assert(self.sorted, 'there has been a DAG structure change before a DAG:updateGradInput')
+   assert(self.sorted, 'There has been a DAG structure change before a DAG:updateGradInput')
 
    self:nestedApply(
       function(nnm, go)
+         local node = self.node[nnm]
+         node.gradOutput = go
          -- nnm:updateGradInput(self.node[nnm].input, go)
-         self:rethrowErrors(nnm, self.node[nnm].index, 'updateGradInput', self.node[nnm].input, go)
+         self:rethrowErrors(nnm, node.index, 'updateGradInput', self.node[nnm].input, go)
       end,
       self.outputModules, gradOutput
    )
@@ -285,12 +293,22 @@ end
 function DAG:accGradParameters(input, gradOutput, scale)
    scale = scale or 1
 
-   assert(self.sorted, 'there has been a DAG structure change before a DAG:accGradParameters')
+   assert(self.sorted, 'There has been a DAG structure change before a DAG:accGradParameters')
+
+   self:nestedApply(
+      function(nnm, go) self.node[nnm].gradOutput = go end,
+      self.outputModules, gradOutput
+   )
+
+   self:nestedApply(
+      function(nnm, i) self.node[nnm].input = i end,
+      self.inputModules, input
+   )
 
    for k = 1, #self.modules do
       local nnm = self.modules[k]
       local node = self.node[nnm]
       -- nnm:accGradParameters(node.input, node.gradOutput, scale)
-      self:rethrowErrors(nnm, k, 'accGradParameters', node.input, self:computeGradOutput(node.gradInputSucc), scale)
+      self:rethrowErrors(nnm, k, 'accGradParameters', node.input, node.gradOutput, scale)
    end
 end