Added README.md
[dyncnn.git] / fftb.lua
1
2 -- Francois Fleuret's Torch Toolbox
3
4 require 'torch'
5 require 'nn'
6
7 ----------------------------------------------------------------------
8
9 colors = sys.COLORS
10
11 function printf(f, ...)
12    print(string.format(f, unpack({...})))
13 end
14
15 function printfc(c, f, ...)
16    printf(c .. string.format(f, unpack({...})) .. colors.black)
17 end
18
19 function logCommand(c)
20    print(colors.blue .. '[' .. c .. '] -> [' .. sys.execute(c) .. ']' .. colors.black)
21 end
22
23 ----------------------------------------------------------------------
24 -- Environment variables
25
26 defaultNbThreads = 1
27 defaultUseGPU = false
28
29 if os.getenv('TORCH_NB_THREADS') then
30    defaultNbThreads = os.getenv('TORCH_NB_THREADS')
31    print('Environment variable TORCH_NB_THREADS is set and equal to ' .. defaultNbThreads)
32 else
33    print('Environment variable TORCH_NB_THREADS is not set, default is ' .. defaultNbThreads)
34 end
35
36 if os.getenv('TORCH_USE_GPU') then
37    defaultUseGPU = os.getenv('TORCH_USE_GPU') == 'yes'
38    print('Environment variable TORCH_USE_GPU is set and evaluated as ' .. tostring(defaultUseGPU))
39 else
40    print('Environment variable TORCH_USE_GPU is not set, default is ' .. tostring(defaultUseGPU))
41 end
42
43 ----------------------------------------------------------------------
44
45 function fftbInit(cmd, params)
46
47    torch.setnumthreads(params.nbThreads)
48    torch.setdefaulttensortype('torch.FloatTensor')
49    torch.manualSeed(params.seed)
50
51    -- Logging
52
53    if params.rundir == '' then
54       params.rundir = cmd:string('experiment', params, { })
55    end
56
57    paths.mkdir(params.rundir)
58
59    if not params.noLog then
60       -- Append to the log if there is one
61       cmd:log(io.open(params.rundir .. '/log', 'a'), params)
62    end
63
64    -- Dealing with the CPU/GPU
65
66    ffnn = {}
67
68    -- By default, ffnn returns the entries from nn
69    local mt = {}
70    function mt.__index(table, key)
71       return (cudnn and cudnn[key]) or (cunn and cunn[key]) or nn[key]
72    end
73    setmetatable(ffnn, mt)
74
75    -- These are the tensors that can be kept on the CPU
76    ffnn.SlowTensor = torch.Tensor
77    ffnn.SlowStorage = torch.Storage
78    -- These are the tensors that should be moved to the GPU
79    ffnn.FastTensor = torch.Tensor
80    ffnn.FastStorage = torch.Storage
81
82    if params.useGPU then
83       require 'cutorch'
84       require 'cunn'
85       require 'cudnn'
86
87       if params.fastGPU then
88          cudnn.benchmark = true
89          cudnn.fastest = true
90       end
91
92       ffnn.FastTensor = torch.CudaTensor
93       ffnn.FastStorage = torch.CudaStorage
94    end
95 end
96
97 ----------------------------------------------------------------------
98
99 function dimAtThatPoint(model, input)
100    if params.useGPU then
101       model:cuda()
102    end
103    local i = ffnn.FastTensor(input:narrow(1, 1, 1):size()):copy(input:narrow(1, 1, 1))
104    return model:forward(i):nElement()
105 end
106
107 ----------------------------------------------------------------------
108
109 function sizeForBatch(n, x)
110    local size = x:size()
111    size[1] = n
112    return size
113 end
114
115 function fillBatch(data, first, batch, permutation)
116    local actualBatchSize = math.min(params.batchSize, data.input:size(1) - first + 1)
117
118    if batch.input then
119       if actualBatchSize ~= batch.input:size(1) then
120          batch.input:resize(sizeForBatch(actualBatchSize, batch.input))
121       end
122    else
123       if torch.isTypeOf(data.input, ffnn.SlowTensor) then
124          batch.input = ffnn.FastTensor(sizeForBatch(actualBatchSize, data.input));
125       else
126          batch.input = data.input.new():resize(sizeForBatch(actualBatchSize, data.input));
127       end
128    end
129
130    if batch.target then
131       if actualBatchSize ~= batch.target:size(1) then
132          batch.target:resize(sizeForBatch(actualBatchSize, batch.target))
133       end
134    else
135       if torch.isTypeOf(data.target, ffnn.SlowTensor) then
136          batch.target = ffnn.FastTensor(sizeForBatch(actualBatchSize, data.target));
137       else
138          batch.target = data.target.new():resize(sizeForBatch(actualBatchSize, data.target));
139       end
140    end
141
142    for k = 1, actualBatchSize do
143       local i
144       if permutation then
145          i = permutation[first + k - 1]
146       else
147          i = first + k - 1
148       end
149       batch.input[k] = data.input[i]
150       batch.target[k] = data.target[i]
151    end
152 end
153
154 ----------------------------------------------------------------------
155
156 --[[
157
158 The combineImage function takes as input a parameter c which is the
159 value to use for the background of the resulting image (padding and
160 such), and t which is either a 2d tensor, a 3d tensor, or a table.
161
162  * If t is a 3d tensor, it is returned unchanged.
163
164  * If t is a 2d tensor [r x c], it is reshaped to [1 x r x c] and
165    returned.
166
167  * If t is a table, combineImage first calls itself recursively on
168    t[1], t[2], etc.
169
170    It then creates a new tensor by concatenating the results
171    horizontally if t.vertical is nil, vertically otherwise.
172
173    It adds a padding of t.pad pixels if this field is set.
174
175  * Example
176
177    x = torch.Tensor(64, 64):fill(0.5)
178    y = torch.Tensor(100, 30):fill(0.85)
179
180    i = combineImages(1.0,
181       {
182          pad = 1,
183          vertical = true,
184          { pad = 1, x },
185          {
186             y,
187             { pad = 4, torch.Tensor(32, 16):fill(0.25) },
188             { pad = 1, torch.Tensor(45, 54):uniform(0.25, 0.9) },
189          }
190       }
191    )
192
193    image.save('example.png', i)
194
195 ]]--
196
197 function combineImages(c, t)
198
199    if torch.isTensor(t) then
200
201       if t:dim() == 3 then
202          return t
203       elseif t:dim() == 2 then
204          return torch.Tensor(1, t:size(1), t:size(2)):copy(t)
205       else
206          error('can only deal with [height x width] or [channel x height x width] tensors.')
207       end
208
209    else
210
211       local subImages = {} -- The subimages
212       local nc = 0 -- Nb of columns
213       local nr = 0 -- Nb of rows
214
215       for i, x in ipairs(t) do
216          subImages[i] = combineImages(c, x)
217          if t.vertical then
218             nr = nr + subImages[i]:size(2)
219             nc = math.max(nc, subImages[i]:size(3))
220          else
221             nr = math.max(nr, subImages[i]:size(2))
222             nc = nc + subImages[i]:size(3)
223          end
224       end
225
226       local pad = t.pad or 0
227       local result = torch.Tensor(subImages[1]:size(1), nr + 2 * pad, nc + 2 * pad):fill(c)
228       local co = 1 + pad -- Origin column
229       local ro = 1 + pad -- Origin row
230
231       for i in ipairs(t) do
232
233          result
234             :sub(1, subImages[1]:size(1),
235                  ro, ro + subImages[i]:size(2) - 1,
236                  co, co + subImages[i]:size(3) - 1)
237             :copy(subImages[i])
238
239          if t.vertical then
240             ro = ro + subImages[i]:size(2)
241          else
242             co = co + subImages[i]:size(3)
243          end
244
245       end
246
247       return result
248
249    end
250
251 end
252
253 --[[
254
255 The imageFromTensors function gets as input a list of tensors of
256 arbitrary dimensions each, but whose two last dimensions stand for
257 height x width. It creates an image tensor (2d, one channel) with each
258 argument tensor unfolded per row.
259
260 ]]--
261
262 function imageFromTensors(bt, signed)
263    local gap = 1
264    local tgap = -1
265    local width = 0
266    local height = gap
267
268    for _, t in pairs(bt) do
269       local d = t:dim()
270       local h, w = t:size(d - 1), t:size(d)
271       local n = t:nElement() / (w * h)
272       width = math.max(width, gap + n * (gap + w))
273       height = height + gap + tgap + gap + h
274    end
275
276    local e = torch.Tensor(3, height, width):fill(1.0)
277    local y0 = 1 + gap
278
279    for _, t in pairs(bt) do
280       local d = t:dim()
281       local h, w = t:size(d - 1), t:size(d)
282       local n = t:nElement() / (w * h)
283       local z = t:norm() / math.sqrt(t:nElement())
284
285       local x0 = 1 + gap + math.floor( (width - n * (w + gap)) /2 )
286       local u = torch.Tensor(t:size()):copy(t):resize(n, h, w)
287       for m = 1, n do
288
289          for c = 1, 3 do
290             for y = 0, h+1 do
291                e[c][y0 + y - 1][x0     - 1] = 0.0
292                e[c][y0 + y - 1][x0 + w    ] = 0.0
293             end
294             for x = 0, w+1 do
295                e[c][y0     - 1][x0 + x - 1] = 0.0
296                e[c][y0 + h    ][x0 + x - 1] = 0.0
297             end
298          end
299
300          for y = 1, h do
301             for x = 1, w do
302                local v = u[m][y][x] / z
303                local r, g, b
304                if signed then
305                   if v < -1 then
306                      r, g, b = 0.0, 0.0, 1.0
307                   elseif v > 1 then
308                      r, g, b = 1.0, 0.0, 0.0
309                   elseif v >= 0 then
310                      r, g, b = 1.0, 1.0 - v, 1.0 - v
311                   else
312                      r, g, b = 1.0 + v, 1.0 + v, 1.0
313                   end
314                else
315                   if v <= 0 then
316                      r, g, b = 1.0, 1.0, 1.0
317                   elseif v > 1 then
318                      r, g, b = 0.0, 0.0, 0.0
319                   else
320                      r, g, b = 1.0 - v, 1.0 - v, 1.0 - v
321                   end
322                end
323                e[1][y0 + y - 1][x0 + x - 1] = r
324                e[2][y0 + y - 1][x0 + x - 1] = g
325                e[3][y0 + y - 1][x0 + x - 1] = b
326             end
327          end
328          x0 = x0 + w + gap
329       end
330       y0 = y0 + h + gap + tgap + gap
331    end
332
333    return e
334 end