Use cudnn more efficiently.
authorFrancois Fleuret <francois@fleuret.org>
Fri, 21 Oct 2016 10:23:03 +0000 (12:23 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Fri, 21 Oct 2016 10:23:03 +0000 (12:23 +0200)
dyncnn.lua

index 839431a..e104386 100755 (executable)
@@ -126,10 +126,10 @@ torch.manualSeed(opt.seed)
 
 mynn = {}
 
--- By default, mynn returns the entries from nn
+-- To deal elegantly with CPU/GPU
 local mt = {}
 function mt.__index(table, key)
-   return nn[key]
+   return (cudnn and cudnn[key]) or (cunn and cunn[key]) or nn[key]
 end
 setmetatable(mynn, mt)
 
@@ -144,8 +144,13 @@ if useGPU then
    require 'cutorch'
    require 'cunn'
    require 'cudnn'
+
    mynn.FastTensor = torch.CudaTensor
-   mynn.SpatialConvolution = cudnn.SpatialConvolution
+
+   if cudnn then
+      cudnn.benchmark = true
+      cudnn.fastest = true
+   end
 end
 
 ----------------------------------------------------------------------