projects
/
profiler-torch.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Fixed the error message, which was causing an error.
[profiler-torch.git]
/
test-profiler.lua
diff --git
a/test-profiler.lua
b/test-profiler.lua
index
7c22576
..
aa6e800
100755
(executable)
--- a/
test-profiler.lua
+++ b/
test-profiler.lua
@@
-39,21
+39,28
@@
require 'profiler'
-- Create a model
-- Create a model
+local w, h, fs = 50, 50, 3
+local nhu = (w - fs + 1) * (h - fs + 1)
+
local model = nn.Sequential()
:add(nn.Sequential()
local model = nn.Sequential()
:add(nn.Sequential()
- :add(nn.Linear(1000, 1000))
+ :add(nn.SpatialConvolution(1, 1, fs, fs))
+ :add(nn.Reshape(nhu))
+ :add(nn.Linear(nhu, 1000))
:add(nn.ReLU())
)
:add(nn.Linear(1000, 100))
-- Decor it for profiling
:add(nn.ReLU())
)
:add(nn.Linear(1000, 100))
-- Decor it for profiling
-profiler.decor(model)
+profiler.decor
ate
(model)
print()
print()
+torch.save('model.t7', model)
+
-- Create the data and criterion
-- Create the data and criterion
-local input = torch.Tensor(1000, 1
000
)
+local input = torch.Tensor(1000, 1
, h, w
)
local target = torch.Tensor(input:size(1), 100)
local criterion = nn.MSECriterion()
local target = torch.Tensor(input:size(1), 100)
local criterion = nn.MSECriterion()
@@
-86,9
+93,9
@@
end
-- Print the accumulated timings
-- Print the accumulated timings
-profiler.print(model, nbSamples)
+-- profiler.color = false
+profiler.print(model, nbSamples, modelTime)
-- profiler.print(model)
-- profiler.print(model)
-print()
print(string.format('Total model time %.02fs', modelTime))
print(string.format('Total data time %.02fs', dataTime))
print(string.format('Total model time %.02fs', modelTime))
print(string.format('Total data time %.02fs', dataTime))