projects
/
profiler-torch.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
So, back to decorating the classes and not the objects so that torch.save() does...
[profiler-torch.git]
/
test-profiler.lua
diff --git
a/test-profiler.lua
b/test-profiler.lua
index
a78c944
..
18677ec
100755
(executable)
--- a/
test-profiler.lua
+++ b/
test-profiler.lua
@@
-39,9
+39,14
@@
require 'profiler'
-- Create a model
-- Create a model
+local w, h, fs = 50, 50, 3
+local nhu = (w - fs + 1) * (h - fs + 1)
+
local model = nn.Sequential()
:add(nn.Sequential()
local model = nn.Sequential()
:add(nn.Sequential()
- :add(nn.Linear(1000, 1000))
+ :add(nn.SpatialConvolution(1, 1, fs, fs))
+ :add(nn.Reshape(nhu))
+ :add(nn.Linear(nhu, 1000))
:add(nn.ReLU())
)
:add(nn.Linear(1000, 100))
:add(nn.ReLU())
)
:add(nn.Linear(1000, 100))
@@
-55,7
+60,7
@@
torch.save('model.t7', model)
-- Create the data and criterion
-- Create the data and criterion
-local input = torch.Tensor(1000, 1
000
)
+local input = torch.Tensor(1000, 1
, h, w
)
local target = torch.Tensor(input:size(1), 100)
local criterion = nn.MSECriterion()
local target = torch.Tensor(input:size(1), 100)
local criterion = nn.MSECriterion()
@@
-88,7
+93,7
@@
end
-- Print the accumulated timings
-- Print the accumulated timings
-profiler.print(model, nbSamples)
+profiler.print(model, nbSamples
, modelTime
)
-- profiler.print(model)
print()
-- profiler.print(model)
print()