X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=test-dagnn.lua;h=32eed5724f5ea84c8e55fe966b54e53021013902;hb=34ed0d49d9b6b03811cd92c9513edf4ec5d4d2d2;hp=262ea6fe3111830ab1f8270118b608725e124881;hpb=452781856eafd237579e5c90b6e345354df91b42;p=dagnn.git diff --git a/test-dagnn.lua b/test-dagnn.lua index 262ea6f..32eed57 100755 --- a/test-dagnn.lua +++ b/test-dagnn.lua @@ -5,6 +5,21 @@ require 'nn' require 'dagnn' +function printTensorTable(t) + if torch.type(t) == 'table' then + for i, t in pairs(t) do + print('-- ELEMENT [' .. i .. '] --') + printTensorTable(t) + end + else + print(tostring(t)) + end +end + +-- torch.setnumthreads(params.nbThreads) +torch.setdefaulttensortype('torch.DoubleTensor') +torch.manualSeed(2) + a = nn.Linear(10, 10) b = nn.ReLU() c = nn.Linear(10, 3) @@ -12,19 +27,13 @@ d = nn.Linear(10, 3) e = nn.CMulTable() f = nn.Linear(3, 2) ---[[ - - a -----> b ---> c ----> e --- - \ / - \--> d ---/ - \ - \---> f --- -]]-- +-- a -----> b ---> c ----> e --- +-- \ / +-- \--> d ---/ +-- \ +-- \---> f --- -g = nn.DAG:new() - -g:setInput(a) -g:setOutput({ e, f }) +g = nn.DAG() g:addEdge(c, e) g:addEdge(a, b) @@ -33,11 +42,22 @@ g:addEdge(b, c) g:addEdge(b, d) g:addEdge(d, f) +g:setInput({{a}}) +g:setOutput({ e, f }) + g:print() input = torch.Tensor(3, 10):uniform() -output = g:updateOutput(input) +output = g:updateOutput({{ input }}) + +printTensorTable(output) + +---------------------------------------------------------------------- + +print('******************************************************************') +print('** updateGradInput ***********************************************') +print('******************************************************************') +gradInput = g:updateGradInput({ input }, output) -print(output[1]) -print(output[2]) +printTensorTable(gradInput)