automatic commit
[folded-ctf.git] / detector.cc
1
2 ///////////////////////////////////////////////////////////////////////////
3 // This program is free software: you can redistribute it and/or modify  //
4 // it under the terms of the version 3 of the GNU General Public License //
5 // as published by the Free Software Foundation.                         //
6 //                                                                       //
7 // This program is distributed in the hope that it will be useful, but   //
8 // WITHOUT ANY WARRANTY; without even the implied warranty of            //
9 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      //
10 // General Public License for more details.                              //
11 //                                                                       //
12 // You should have received a copy of the GNU General Public License     //
13 // along with this program. If not, see <http://www.gnu.org/licenses/>.  //
14 //                                                                       //
15 // Written by Francois Fleuret                                           //
16 // (C) Idiap Research Institute                                          //
17 //                                                                       //
18 // Contact <francois.fleuret@idiap.ch> for comments & bug reports        //
19 ///////////////////////////////////////////////////////////////////////////
20
21 #include "tools.h"
22 #include "detector.h"
23 #include "global.h"
24 #include "classifier_reader.h"
25 #include "pose_cell_hierarchy_reader.h"
26
27 Detector::Detector() {
28   _hierarchy = 0;
29   _nb_levels = 0;
30   _nb_classifiers_per_level = 0;
31   _thresholds = 0;
32   _nb_classifiers = 0;
33   _classifiers = 0;
34   _pi_feature_families = 0;
35 }
36
37
38 Detector::~Detector() {
39   if(_hierarchy) {
40     delete[] _thresholds;
41     for(int q = 0; q < _nb_classifiers; q++) {
42       delete _classifiers[q];
43       delete _pi_feature_families[q];
44     }
45     delete[] _classifiers;
46     delete[] _pi_feature_families;
47     delete _hierarchy;
48   }
49 }
50
51 //////////////////////////////////////////////////////////////////////
52 // Training
53
54 void Detector::train_classifier(int level,
55                                 LossMachine *loss_machine,
56                                 ParsingPool *parsing_pool,
57                                 PiFeatureFamily *pi_feature_family,
58                                 Classifier *classifier) {
59
60   // Randomize the pi-feature family
61
62   PiFeatureFamily full_pi_feature_family;
63
64   full_pi_feature_family.resize(global.nb_features_for_boosting_optimization);
65   full_pi_feature_family.randomize(level);
66
67   int nb_positives = parsing_pool->nb_positive_cells();
68
69   int nb_negatives_to_sample =
70     parsing_pool->nb_positive_cells() * global.nb_negative_samples_per_positive;
71
72   SampleSet *sample_set = new SampleSet(full_pi_feature_family.nb_features(),
73                                         nb_positives + nb_negatives_to_sample);
74
75   scalar_t *responses = new scalar_t[nb_positives + nb_negatives_to_sample];
76
77   parsing_pool->weighted_sampling(loss_machine,
78                                   &full_pi_feature_family,
79                                   sample_set,
80                                   responses);
81
82   (*global.log_stream) << "Initial train_loss "
83                        << loss_machine->loss(sample_set, responses)
84                        << endl;
85
86   classifier->train(loss_machine, sample_set, responses);
87   classifier->extract_pi_feature_family(&full_pi_feature_family, pi_feature_family);
88
89   delete[] responses;
90   delete sample_set;
91 }
92
93 void Detector::train(LabelledImagePool *train_pool,
94                      LabelledImagePool *validation_pool,
95                      LabelledImagePool *hierarchy_pool) {
96
97   if(_hierarchy) {
98     cerr << "Can not re-train a Detector" << endl;
99     exit(1);
100   }
101
102   _hierarchy = new PoseCellHierarchy(hierarchy_pool);
103
104   int nb_violations;
105
106   nb_violations = _hierarchy->nb_incompatible_poses(train_pool);
107
108   if(nb_violations > 0) {
109     cout << "The hierarchy is incompatible with the training set ("
110          << nb_violations
111          << " violations)." << endl;
112     exit(1);
113   }
114
115   nb_violations = _hierarchy->nb_incompatible_poses(validation_pool);
116
117   if(nb_violations > 0) {
118     cout << "The hierarchy is incompatible with the validation set ("
119          << nb_violations << " violations)."
120          << endl;
121     exit(1);
122   }
123
124   _nb_levels = _hierarchy->nb_levels();
125   _nb_classifiers_per_level = global.nb_classifiers_per_level;
126   _nb_classifiers = _nb_levels * _nb_classifiers_per_level;
127   _thresholds = new scalar_t[_nb_classifiers];
128   _classifiers = new Classifier *[_nb_classifiers];
129   _pi_feature_families = new PiFeatureFamily *[_nb_classifiers];
130
131   for(int q = 0; q < _nb_classifiers; q++) {
132     _classifiers[q] = new BoostedClassifier(global.nb_weak_learners_per_classifier);
133     _pi_feature_families[q] = new PiFeatureFamily();
134   }
135
136   ParsingPool *train_parsing, *validation_parsing;
137
138   train_parsing = new ParsingPool(train_pool,
139                                   _hierarchy,
140                                   global.proportion_negative_cells_for_training);
141
142   if(global.write_validation_rocs) {
143     validation_parsing = new ParsingPool(validation_pool,
144                                          _hierarchy,
145                                          global.proportion_negative_cells_for_training);
146   } else {
147     validation_parsing = 0;
148   }
149
150   LossMachine *loss_machine = new LossMachine(global.loss_type);
151
152   cout << "Building a detector." << endl;
153
154   global.bar.init(&cout, _nb_classifiers);
155
156   for(int l = 0; l < _nb_levels; l++) {
157
158     if(l > 0) {
159       train_parsing->down_one_level(loss_machine, _hierarchy, l);
160       if(validation_parsing) {
161         validation_parsing->down_one_level(loss_machine, _hierarchy, l);
162       }
163     }
164
165     for(int c = 0; c < _nb_classifiers_per_level; c++) {
166       int q = l * _nb_classifiers_per_level + c;
167
168       // Train the classifier
169
170       train_classifier(l,
171                        loss_machine,
172                        train_parsing,
173                        _pi_feature_families[q], _classifiers[q]);
174
175       // Update the cell responses on the training set
176
177       train_parsing->update_cell_responses(_pi_feature_families[q],
178                                            _classifiers[q]);
179
180       // Save the ROC curves on the training set
181
182       char buffer[buffer_size];
183
184       sprintf(buffer, "%s/train_%05d.roc",
185               global.result_path,
186               (q + 1) * global.nb_weak_learners_per_classifier);
187       ofstream out(buffer);
188       train_parsing->write_roc(&out);
189
190       if(validation_parsing) {
191
192         // Update the cell responses on the validation set
193
194         validation_parsing->update_cell_responses(_pi_feature_families[q],
195                                                   _classifiers[q]);
196
197         // Save the ROC curves on the validation set
198
199         sprintf(buffer, "%s/validation_%05d.roc",
200                 global.result_path,
201                 (q + 1) * global.nb_weak_learners_per_classifier);
202         ofstream out(buffer);
203         validation_parsing->write_roc(&out);
204       }
205
206       _thresholds[q] = 0.0;
207
208       global.bar.refresh(&cout, q);
209     }
210   }
211
212   global.bar.finish(&cout);
213
214   delete loss_machine;
215   delete train_parsing;
216   delete validation_parsing;
217 }
218
219 void Detector::compute_thresholds(LabelledImagePool *validation_pool, scalar_t wanted_tp) {
220   LabelledImage *image;
221   int nb_targets_total = 0;
222
223   for(int i = 0; i < validation_pool->nb_images(); i++) {
224     image = validation_pool->grab_image(i);
225     nb_targets_total += image->nb_targets();
226     validation_pool->release_image(i);
227   }
228
229   scalar_t *responses = new scalar_t[_nb_classifiers * nb_targets_total];
230
231   int tt = 0;
232
233   for(int i = 0; i < validation_pool->nb_images(); i++) {
234     image = validation_pool->grab_image(i);
235     image->compute_rich_structure();
236
237     PoseCell current_cell;
238
239     for(int t = 0; t < image->nb_targets(); t++) {
240
241       scalar_t response = 0;
242
243       for(int l = 0; l < _nb_levels; l++) {
244
245         // We get the next-level cell for that target
246
247         PoseCellSet cell_set;
248
249         cell_set.erase_content();
250         if(l == 0) {
251           _hierarchy->add_root_cells(image, &cell_set);
252         } else {
253           _hierarchy->add_subcells(l, &current_cell, &cell_set);
254         }
255
256         int nb_compliant = 0;
257
258         for(int c = 0; c < cell_set.nb_cells(); c++) {
259           if(cell_set.get_cell(c)->contains(image->get_target_pose(t))) {
260             current_cell = *(cell_set.get_cell(c));
261             nb_compliant++;
262           }
263         }
264
265         if(nb_compliant != 1) {
266           cerr << "INCONSISTENCY (" << nb_compliant << " should be one)" << endl;
267           abort();
268         }
269
270         for(int c = 0; c < _nb_classifiers_per_level; c++) {
271           int q = l * _nb_classifiers_per_level + c;
272           SampleSet *sample_set = new SampleSet(_pi_feature_families[q]->nb_features(), 1);
273           sample_set->set_sample(0, _pi_feature_families[q], image, &current_cell, 0);
274           response +=_classifiers[q]->response(sample_set, 0);
275           delete sample_set;
276           responses[tt + nb_targets_total * q] = response;
277         }
278
279       }
280
281       tt++;
282     }
283
284     validation_pool->release_image(i);
285   }
286
287   ASSERT(tt == nb_targets_total);
288
289   // Here we have in responses[] all the target responses after every
290   // classifier
291
292   int *still_detected = new int[nb_targets_total];
293   int *indexes = new int[nb_targets_total];
294   int *sorted_indexes = new int[nb_targets_total];
295
296   for(int t = 0; t < nb_targets_total; t++) {
297     still_detected[t] = 1;
298     indexes[t] = t;
299   }
300
301   int current_nb_fn = 0;
302
303   for(int q = 0; q < _nb_classifiers; q++) {
304
305     scalar_t wanted_tp_at_this_classifier
306       = exp(log(wanted_tp) * scalar_t(q + 1) / scalar_t(_nb_classifiers));
307
308     int wanted_nb_fn_at_this_classifier
309       = int(nb_targets_total * (1 - wanted_tp_at_this_classifier));
310
311     indexed_fusion_sort(nb_targets_total, indexes, sorted_indexes,
312                         responses + q * nb_targets_total);
313
314     for(int t = 0; (current_nb_fn < wanted_nb_fn_at_this_classifier) && (t < nb_targets_total - 1); t++) {
315       int u = sorted_indexes[t];
316       int v = sorted_indexes[t+1];
317       _thresholds[q] = responses[v + nb_targets_total * q];
318       if(still_detected[u]) {
319         still_detected[u] = 0;
320         current_nb_fn++;
321       }
322     }
323   }
324
325   delete[] still_detected;
326   delete[] indexes;
327   delete[] sorted_indexes;
328   delete[] responses;
329 }
330
331 //////////////////////////////////////////////////////////////////////
332 // Parsing
333
334 void Detector::parse_rec(RichImage *image, int level,
335                          PoseCell *cell, scalar_t current_response,
336                          PoseCellScoredSet *result) {
337
338   if(level == _nb_levels) {
339     result->add_cell_with_score(cell, current_response);
340     return;
341   }
342
343   PoseCellSet cell_set;
344   cell_set.erase_content();
345
346   if(level == 0) {
347     _hierarchy->add_root_cells(image, &cell_set);
348   } else {
349     _hierarchy->add_subcells(level, cell, &cell_set);
350   }
351
352   scalar_t *responses = new scalar_t[cell_set.nb_cells()];
353   int *keep = new int[cell_set.nb_cells()];
354
355   for(int c = 0; c < cell_set.nb_cells(); c++) {
356     responses[c] = current_response;
357     keep[c] = 1;
358   }
359
360   for(int a = 0; a < _nb_classifiers_per_level; a++) {
361     int q = level * _nb_classifiers_per_level + a;
362     SampleSet *samples = new SampleSet(_pi_feature_families[q]->nb_features(), 1);
363     for(int c = 0; c < cell_set.nb_cells(); c++) {
364       if(keep[c]) {
365         samples->set_sample(0, _pi_feature_families[q], image, cell_set.get_cell(c), 0);
366         responses[c] += _classifiers[q]->response(samples, 0);
367         keep[c] = responses[c] >= _thresholds[q];
368       }
369     }
370     delete samples;
371   }
372
373   for(int c = 0; c < cell_set.nb_cells(); c++) {
374     if(keep[c]) {
375       parse_rec(image, level + 1, cell_set.get_cell(c), responses[c], result);
376     }
377   }
378
379   delete[] keep;
380   delete[] responses;
381 }
382
383 void Detector::parse(RichImage *image, PoseCellScoredSet *result_cell_set) {
384   result_cell_set->erase_content();
385   parse_rec(image, 0, 0, 0, result_cell_set);
386 }
387
388 //////////////////////////////////////////////////////////////////////
389 // Storage
390
391 void Detector::read(istream *is) {
392   if(_hierarchy) {
393     cerr << "Can not read over an existing Detector" << endl;
394     exit(1);
395   }
396
397   read_var(is, &_nb_levels);
398   read_var(is, &_nb_classifiers_per_level);
399
400   _nb_classifiers = _nb_levels * _nb_classifiers_per_level;
401
402   _classifiers = new Classifier *[_nb_classifiers];
403   _pi_feature_families = new PiFeatureFamily *[_nb_classifiers];
404   _thresholds = new scalar_t[_nb_classifiers];
405
406   for(int q = 0; q < _nb_classifiers; q++) {
407     _pi_feature_families[q] = new PiFeatureFamily();
408     _pi_feature_families[q]->read(is);
409     _classifiers[q] = read_classifier(is);
410     read_var(is, &_thresholds[q]);
411   }
412
413   _hierarchy = read_hierarchy(is);
414 }
415
416 void Detector::write(ostream *os) {
417   write_var(os, &_nb_levels);
418   write_var(os, &_nb_classifiers_per_level);
419
420   for(int q = 0; q < _nb_classifiers; q++) {
421     _pi_feature_families[q]->write(os);
422     _classifiers[q]->write(os);
423     write_var(os, &_thresholds[q]);
424   }
425
426   _hierarchy->write(os);
427 }