Made the checks a bit more demanding. Everything seems in order.
authorFrancois Fleuret <francois@fleuret.org>
Fri, 13 Jan 2017 21:58:43 +0000 (22:58 +0100)
committerFrancois Fleuret <francois@fleuret.org>
Fri, 13 Jan 2017 21:58:43 +0000 (22:58 +0100)
test-dagnn.lua

index f7de819..9f343a9 100755 (executable)
@@ -99,7 +99,11 @@ dag:connect(c, e)
 dag:setInput(a)
 dag:setOutput({ d, e })
 
 dag:setInput(a)
 dag:setOutput({ d, e })
 
--- We check it works when we put it into a nn.Sequential
+-- Check the output of the dot file
+print('Writing /tmp/graph.dot')
+dag:saveDot('/tmp/graph.dot')
+
+-- Let's make a model where the dag is inside another nn.Container.
 model = nn.Sequential()
    :add(nn.Linear(50, 50))
    :add(dag)
 model = nn.Sequential()
    :add(nn.Linear(50, 50))
    :add(dag)
@@ -109,7 +113,11 @@ local input = torch.Tensor(30, 50):uniform()
 local output = model:updateOutput(input):clone()
 output:uniform()
 
 local output = model:updateOutput(input):clone()
 output:uniform()
 
+-- Check that DAG:accGradParameters and friends work okay
 print('Gradient estimate error ' .. checkGrad(model, nn.MSECriterion(), input, output))
 
 print('Gradient estimate error ' .. checkGrad(model, nn.MSECriterion(), input, output))
 
-print('Writing /tmp/graph.dot')
-dag:saveDot('/tmp/graph.dot')
+-- Check that we can save and reload the model
+model:clearState()
+torch.save('/tmp/test.t7', model)
+local otherModel = torch.load('/tmp/test.t7')
+print('Gradient estimate error ' .. checkGrad(otherModel, nn.MSECriterion(), input, output))