From 8bc9a2b89adb6f08a76b4e393025a2f5b2999aa6 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Thu, 12 Jan 2017 18:46:06 +0100 Subject: [PATCH] Cosmetics. --- test-dagnn.lua | 40 ++++++++++------------------------------ 1 file changed, 10 insertions(+), 30 deletions(-) diff --git a/test-dagnn.lua b/test-dagnn.lua index 5b266da..1df04e2 100755 --- a/test-dagnn.lua +++ b/test-dagnn.lua @@ -21,9 +21,12 @@ require 'torch' require 'nn' - require 'dagnn' +-- torch.setnumthreads(params.nbThreads) +torch.setdefaulttensortype('torch.DoubleTensor') +torch.manualSeed(2) + function checkGrad(model, criterion, input, target) local params, gradParams = model:getParameters() @@ -81,10 +84,6 @@ function printTensorTable(t) end end --- torch.setnumthreads(params.nbThreads) -torch.setdefaulttensortype('torch.DoubleTensor') -torch.manualSeed(2) - -- +--> c ----> e --+ -- / / \ -- / / \ @@ -93,12 +92,12 @@ torch.manualSeed(2) -- \ / -- +--> f ---+ -a = nn.Linear(10, 10) +a = nn.Linear(50, 10) b = nn.ReLU() -c = nn.Linear(10, 3) -d = nn.Linear(10, 3) +c = nn.Linear(10, 15) +d = nn.Linear(10, 15) e = nn.CMulTable() -f = nn.Linear(3, 3) +f = nn.Linear(15, 15) g = nn.CAddTable() model = nn.DAG() @@ -115,27 +114,8 @@ model:addEdge(f, g) model:setInput(a) model:setOutput(g) -input = torch.Tensor(3, 10):uniform() - -print('******************************************************************') -print('** updateOutput **************************************************') -print('******************************************************************') - -output = model:updateOutput(input):clone() - -printTensorTable(output) - -print('******************************************************************') -print('** updateGradInput ***********************************************') -print('******************************************************************') - -gradInput = model:updateGradInput(input, output) - -printTensorTable(gradInput) - -print('******************************************************************') -print('** checkGrad *****************************************************') -print('******************************************************************') +local input = torch.Tensor(30, 50):uniform() +local output = model:updateOutput(input):clone() output:uniform() -- 2.20.1