X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=dagnn.git;a=blobdiff_plain;f=test-dagnn.lua;h=a45d6365d61a6183b7b1b49758cddb89272d7708;hp=262ea6fe3111830ab1f8270118b608725e124881;hb=682b76200f755f5f16477e086056a86cafdea1cd;hpb=452781856eafd237579e5c90b6e345354df91b42 diff --git a/test-dagnn.lua b/test-dagnn.lua index 262ea6f..a45d636 100755 --- a/test-dagnn.lua +++ b/test-dagnn.lua @@ -5,6 +5,10 @@ require 'nn' require 'dagnn' +-- torch.setnumthreads(params.nbThreads) +torch.setdefaulttensortype('torch.DoubleTensor') +torch.manualSeed(2) + a = nn.Linear(10, 10) b = nn.ReLU() c = nn.Linear(10, 3) @@ -24,14 +28,16 @@ f = nn.Linear(3, 2) g = nn.DAG:new() g:setInput(a) -g:setOutput({ e, f }) +g:setOutput({ e }) g:addEdge(c, e) g:addEdge(a, b) g:addEdge(d, e) g:addEdge(b, c) g:addEdge(b, d) -g:addEdge(d, f) +-- g:addEdge(d, f) + +-- g = torch.load('dag.t7') g:print() @@ -39,5 +45,12 @@ input = torch.Tensor(3, 10):uniform() output = g:updateOutput(input) -print(output[1]) -print(output[2]) +if torch.type(output) == 'table' then + for i, t in pairs(output) do + print(tostring(i) .. ' -> ' .. tostring(t)) + end +else + print(tostring(output)) +end + +torch.save('dag.t7', g)