Update.
[dagnn.git] / test-dagnn.lua
1 #!/usr/bin/env luajit
2
3 require 'torch'
4 require 'nn'
5
6 require 'dagnn'
7
8 function printTensorTable(t)
9    if torch.type(t) == 'table' then
10       for i, t in pairs(t) do
11          print('-- ELEMENT [' .. i .. '] --')
12          printTensorTable(t)
13       end
14    else
15       print(tostring(t))
16    end
17 end
18
19 -- torch.setnumthreads(params.nbThreads)
20 torch.setdefaulttensortype('torch.DoubleTensor')
21 torch.manualSeed(2)
22
23 a = nn.Linear(10, 10)
24 b = nn.ReLU()
25 c = nn.Linear(10, 3)
26 d = nn.Linear(10, 3)
27 e = nn.CMulTable()
28 f = nn.Linear(3, 2)
29
30 --   a -----> b ---> c ----> e ---
31 --             \           /
32 --              \--> d ---/
33 --                    \
34 --                     \---> f ---
35
36 g = nn.DAG()
37
38 g:addEdge(c, e)
39 g:addEdge(a, b)
40 g:addEdge(d, e)
41 g:addEdge(b, c)
42 g:addEdge(b, d)
43 g:addEdge(d, f)
44
45 g:setInput({{a}})
46 g:setOutput({ e, f })
47
48 g:print()
49
50 input = torch.Tensor(3, 10):uniform()
51
52 output = g:updateOutput({{ input }})
53
54 printTensorTable(output)
55
56 ----------------------------------------------------------------------
57
58 print('******************************************************************')
59 print('** updateGradInput ***********************************************')
60 print('******************************************************************')
61 gradInput = g:updateGradInput({{input}}, output)
62
63 printTensorTable(gradInput)