The analytic gradient checks out.
[dagnn.git] / test-dagnn.lua
1 #!/usr/bin/env luajit
2
3 require 'torch'
4 require 'nn'
5
6 require 'dagnn'
7
8 function checkGrad(model, criterion, input, target)
9    local params, gradParams = model:getParameters()
10
11    local epsilon = 1e-5
12
13    local output = model:forward(input)
14    local loss = criterion:forward(output, target)
15    local gradOutput = criterion:backward(output, target)
16    gradParams:zero()
17    model:backward(input, gradOutput)
18    local analyticalGradParam = gradParams:clone()
19
20    for i = 1, params:size(1) do
21       local x = params[i]
22
23       params[i] = x - epsilon
24       local output0 = model:forward(input)
25       local loss0 = criterion:forward(output0, target)
26
27       params[i] = x + epsilon
28       local output1 = model:forward(input)
29       local loss1 = criterion:forward(output1, target)
30
31       params[i] = x
32
33       local ana = analyticalGradParam[i]
34       local num = (loss1 - loss0) / (2 * epsilon)
35       local err = torch.abs(num - ana) / torch.abs(num)
36
37       print(
38          err .. ' checkGrad ' .. i
39             .. ' analytical ' .. ana
40             .. ' numerical ' .. num
41       )
42    end
43
44 end
45
46 function printTensorTable(t)
47    if torch.type(t) == 'table' then
48       for i, t in pairs(t) do
49          print('-- ELEMENT [' .. i .. '] --')
50          printTensorTable(t)
51       end
52    else
53       print(tostring(t))
54    end
55 end
56
57 -- torch.setnumthreads(params.nbThreads)
58 torch.setdefaulttensortype('torch.DoubleTensor')
59 torch.manualSeed(2)
60
61 --                     +--> c ----> e --+
62 --                    /            /     \
63 --                   /            /       \
64 --  input --> a --> b ---> d ----+         g --> output
65 --                          \             /
66 --                           \           /
67 --                            +--> f ---+
68
69 a = nn.Linear(10, 10)
70 b = nn.ReLU()
71 c = nn.Linear(10, 3)
72 d = nn.Linear(10, 3)
73 e = nn.CMulTable()
74 f = nn.Linear(3, 3)
75 g = nn.CAddTable()
76
77 ----------------------------------------------------------------------
78
79 model = nn.DAG()
80
81 model:addEdge(a, b)
82 model:addEdge(b, c)
83 model:addEdge(b, d)
84 model:addEdge(c, e)
85 model:addEdge(d, e)
86 model:addEdge(d, f)
87 model:addEdge(e, g)
88 model:addEdge(f, g)
89
90 model:setInput(a)
91 model:setOutput(g)
92
93 input = torch.Tensor(3, 10):uniform()
94
95 print('******************************************************************')
96 print('** updateOutput **************************************************')
97 print('******************************************************************')
98
99 output = model:updateOutput(input):clone()
100
101 printTensorTable(output)
102
103 print('******************************************************************')
104 print('** updateGradInput ***********************************************')
105 print('******************************************************************')
106
107 gradInput = model:updateGradInput(input, output)
108
109 printTensorTable(gradInput)
110
111 print('******************************************************************')
112 print('** checkGrad *****************************************************')
113 print('******************************************************************')
114
115 output:uniform()
116
117 checkGrad(model, nn.MSECriterion(), input, output)