b394a332898ac52f72e80201382e50ce95922ef6
[profiler-torch.git] / test-profiler.lua
1 #!/usr/bin/env luajit
2
3 require 'torch'
4 require 'nn'
5
6 require 'profiler'
7
8 local model = nn.Sequential()
9 model:add(nn.Linear(1000, 1000))
10 model:add(nn.ReLU())
11 model:add(nn.Linear(1000, 100))
12
13 profiler.decor(model)
14
15 for k = 1, 10 do
16    local input = torch.Tensor(1000, 1000):uniform(-1, 1)
17    local target = torch.Tensor(input:size(1), 100):uniform()
18    local criterion = nn.MSECriterion()
19    local output = model:forward(input)
20    local loss = criterion:forward(output, target)
21    local dloss = criterion:backward(output, target)
22    model:backward(input, dloss)
23 end
24
25 profiler.print(model)