#!/usr/bin/env luajit

-- Small smoke test for nn.DAG: wires six nn modules into a directed
-- acyclic graph with one input node and two output nodes, prints the
-- graph, then runs one forward pass on a random batch and prints both
-- results.

require 'torch'
require 'nn'
require 'dagnn'

-- Graph nodes. Everything is declared local so that running the script
-- does not leak single-letter names into the global environment.
local a = nn.Linear(10, 10)
local b = nn.ReLU()
local c = nn.Linear(10, 3)
local d = nn.Linear(10, 3)
local e = nn.CMulTable() -- receives two incoming edges (from c and d)
local f = nn.Linear(3, 2)

--[[

   a -----> b ---> c ----> e ---
             \           /
              \--> d ---/
                    \
                     \---> f ---

]]--

local g = nn.DAG:new()

g:setInput(a)
g:setOutput({ e, f })

-- Edges are deliberately added in an order unrelated to the topology
-- shown in the diagram above.
g:addEdge(c, e)
g:addEdge(a, b)
g:addEdge(d, e)
g:addEdge(b, c)
g:addEdge(b, d)
g:addEdge(d, f)

g:print()

-- 3 x 10 tensor filled by uniform(); the second dimension matches the
-- 10 input features of the Linear module `a`.
local input = torch.Tensor(3, 10):uniform()

local output = g:updateOutput(input)

-- NOTE(review): output[1] / output[2] presumably correspond to e and f,
-- in the order passed to setOutput -- confirm against nn.DAG.
print(output[1])
print(output[2])