#!/usr/bin/env luajit

-- Smoke test for nn.DAG (dagnn): builds a small graph with a single input
-- module and two output modules, runs a forward pass (updateOutput) and a
-- backward pass (updateGradInput), and prints the resulting tensors.

require 'torch'
require 'nn'
require 'dagnn'

--- Recursively print a tensor or a (possibly nested) table of tensors.
-- Each table element is preceded by an '-- ELEMENT [key] --' header line.
-- Iteration uses pairs(), so element order is unspecified for tables that
-- are not plain sequences.
-- @param value a tensor, a table of tensors/tables, or any value tostring()
--   accepts
local function printTensorTable(value)
   if torch.type(value) == 'table' then
      for key, element in pairs(value) do
         print('-- ELEMENT [' .. key .. '] --')
         printTensorTable(element)
      end
   else
      print(tostring(value))
   end
end

-- torch.setnumthreads(params.nbThreads)

torch.setdefaulttensortype('torch.DoubleTensor')
torch.manualSeed(2)

local a = nn.Linear(10, 10)
local b = nn.ReLU()
local c = nn.Linear(10, 3)
local d = nn.Linear(10, 3)
local e = nn.CMulTable()
local f = nn.Linear(3, 2)

-- a -----> b ---> c ----> e ---
--           \            /
--            \--> d ----/
--                  \
--                   \---> f ---

local g = nn.DAG()

-- Edges are not added in topological order (c -> e comes before b -> c),
-- presumably to exercise the DAG's own ordering logic.
g:addEdge(c, e)
g:addEdge(a, b)
g:addEdge(d, e)
g:addEdge(b, c)
g:addEdge(b, d)
g:addEdge(d, f)

g:setInput({{a}})
g:setOutput({ e, f })

g:print()

local input = torch.Tensor(3, 10):uniform()
local output = g:updateOutput({{ input }})

printTensorTable(output)

----------------------------------------------------------------------

print('******************************************************************')
print('** updateGradInput ***********************************************')
print('******************************************************************')

-- The forward output is reused as the gradient w.r.t. the output. It has
-- the same structure as the graph's output, which is all this smoke test
-- needs; the values themselves are not meaningful gradients.
local gradInput = g:updateGradInput({{input}}, output)

printTensorTable(gradInput)