Replaced error() with assert().
authorFrancois Fleuret <francois@fleuret.org>
Sat, 14 Jan 2017 21:14:34 +0000 (22:14 +0100)
committerFrancois Fleuret <francois@fleuret.org>
Sat, 14 Jan 2017 21:14:34 +0000 (22:14 +0100)
dagnn.lua

index 67e113d..cf45233 100755 (executable)
--- a/dagnn.lua
+++ b/dagnn.lua
@@ -146,12 +146,8 @@ function DAG:setInput(i)
    self.inputModules = i
    self:nestedApply(
       function(nnm)
-         if #self.node[nnm].succ == 0 then
-            error('Input modules must have outgoing edges.')
-         end
-         if #self.node[nnm].pred > 0 then
-            error('Input modules cannot have incoming edges.')
-         end
+         assert(#self.node[nnm].succ > 0, 'Input modules must have outgoing edges.')
+         assert(#self.node[nnm].pred == 0, 'Input modules cannot have incoming edges.')
       end,
       self.inputModules
    )
@@ -162,12 +158,8 @@ function DAG:setOutput(o)
    self.outputModules = o
    self:nestedApply(
       function(nnm)
-         if #self.node[nnm].pred == 0 then
-            error('Output module must have incoming edges.')
-         end
-         if #self.node[nnm].succ > 0 then
-            error('Output module cannot have outgoing edges.')
-         end
+         assert(#self.node[nnm].pred > 0, 'Output module must have incoming edges.')
+         assert(#self.node[nnm].succ == 0, 'Output module cannot have outgoing edges.')
       end,
       self.outputModules
    )
@@ -296,9 +288,8 @@ function DAG:updateGradInput(input, gradOutput)
       if #pred == 1 then
          table.insert(self.node[pred[1]].gradInputSucc, nnm.gradInput)
       elseif #pred > 1 then
-         if not torch.type(nnm.gradInput) == 'table' then
-            error('Should have a table gradInput since it has multiple predecessors')
-         end
+         assert(torch.type(nnm.gradInput) == 'table',
+                'Should have a table gradInput since it has multiple predecessors')
          for n = 1, #pred do
             table.insert(self.node[node.pred[n]].gradInputSucc, nnm.gradInput[n])
          end