X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=dagnn.git;a=blobdiff_plain;f=test-dagnn.lua;h=38019565baf775c90bf40b47256c09b1caacaa5b;hp=366e98f620b89b1f793284e3317489909a371ce4;hb=2de7b6da5330f15b0aef73cdff7cae472c25b037;hpb=9dad4fa1118632bfa02c01e4d6a8a5a129061a54 diff --git a/test-dagnn.lua b/test-dagnn.lua index 366e98f..3801956 100755 --- a/test-dagnn.lua +++ b/test-dagnn.lua @@ -23,9 +23,8 @@ require 'torch' require 'nn' require 'dagnn' --- torch.setnumthreads(params.nbThreads) torch.setdefaulttensortype('torch.DoubleTensor') -torch.manualSeed(2) +torch.manualSeed(1) function checkGrad(model, criterion, input, target) local params, gradParams = model:getParameters() @@ -92,10 +91,9 @@ c = nn.Linear(10, 15) d = nn.CMulTable() e = nn.CAddTable() -model:connect(a, b) +model:connect(a, b, c) model:connect(b, nn.Linear(10, 15), nn.ReLU(), d) model:connect(d, e) -model:connect(b, c) model:connect(c, d) model:connect(c, nn.Mul(-1), e) @@ -110,4 +108,4 @@ output:uniform() print('Error = ' .. checkGrad(model, nn.MSECriterion(), input, output)) print('Writing /tmp/graph.dot') -model:dot('/tmp/graph.dot') +model:saveDot('/tmp/graph.dot')