Update.
[dagnn.git] / test-dagnn.lua
#!/usr/bin/env luajit

-- Smoke test for the DAG container provided by `dagnn`: wires a small graph
-- of nn modules, orders it, prints it, and runs one forward pass on a random
-- input, printing both outputs.

require 'torch'
require 'nn'

require 'dagnn'

-- Fix the RNG seed so that the module weights (drawn when the modules are
-- constructed below) and the random input are reproducible run to run.
torch.manualSeed(0)

-- All script state is local; nothing here needs to leak into _G.
local a = nn.Linear(10, 10)
local b = nn.ReLU()
local c = nn.Linear(10, 3)
local d = nn.Linear(10, 3)
local e = nn.CMulTable()
local f = nn.Linear(3, 2)

--[[

   a -----> b ---> c ----> e ---
             \           /
              \--> d ---/
                    \
                     \---> f ---
]]--

local g = DAG:new()

g:setInput(a)
g:setOutput({ e, f })
-- Edges are added out of topological order (c->e before a->b), presumably
-- so that g:order() below has real sorting work to do.
g:addEdge(c, e)
g:addEdge(a, b)
g:addEdge(d, e)
g:addEdge(b, c)
g:addEdge(b, d)
g:addEdge(d, f)

g:order()

-- NOTE(review): `graph` is not defined anywhere in this script; it is either
-- nil or a global installed by one of the required modules. Confirm what
-- DAG:print expects for this argument.
g:print(graph)

-- 3 rows of 10 features each; 10 matches the input size of `a`.
local input = torch.Tensor(3, 10):uniform()

local output = g:updateOutput(input)

-- Presumably output[1] comes from `e` and output[2] from `f`, following the
-- order passed to setOutput -- TODO confirm against DAG:updateOutput.
print(output[1])
print(output[2])