2 * folded-ctf is an implementation of the folded hierarchy of
3 * classifiers for object detection, developed by Francois Fleuret
6 * Copyright (c) 2008 Idiap Research Institute, http://www.idiap.ch/
7 * Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 * This file is part of folded-ctf.
11 * folded-ctf is free software: you can redistribute it and/or modify
12 * it under the terms of the GNU General Public License as published
13 * by the Free Software Foundation, either version 3 of the License,
14 * or (at your option) any later version.
16 * folded-ctf is distributed in the hope that it will be useful, but
17 * WITHOUT ANY WARRANTY; without even the implied warranty of
18 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 * General Public License for more details.
21 * You should have received a copy of the GNU General Public License
22 * along with folded-ctf. If not, see <http://www.gnu.org/licenses/>.
29 #include "classifier_reader.h"
30 #include "pose_cell_hierarchy_reader.h"
// Default constructor. The statements visible here reset the number of
// classifiers per level to zero and null the array of pi-feature families.
// NOTE(review): the embedded line numbering skips lines inside this body, so
// other members are presumably reset on lines not shown here -- confirm.
32 Detector::Detector() {
35 _nb_classifiers_per_level = 0;
39 _pi_feature_families = 0;
// Destructor. Releases every classifier and every pi-feature family stored
// in the per-classifier arrays, then the two pointer arrays themselves.
// NOTE(review): any guard around this cleanup, and the release of the other
// members, are on lines not visible in this excerpt.
43 Detector::~Detector() {
// One Classifier and one PiFeatureFamily are owned per classifier index q.
46 for(int q = 0; q < _nb_classifiers; q++) {
47 delete _classifiers[q];
48 delete _pi_feature_families[q];
// The arrays of pointers were allocated with new[], hence delete[].
50 delete[] _classifiers;
51 delete[] _pi_feature_families;
56 //////////////////////////////////////////////////////////////////////
// Trains one classifier of the cascade.
//
//   level             - passed to PiFeatureFamily::randomize() when drawing
//                       the candidate features
//   loss_machine      - used for sampling, for the logged loss and for the
//                       training itself
//   parsing_pool      - source of the positive cells and of the sampled
//                       negatives
//   pi_feature_family - output: receives the features extracted from the
//                       trained classifier
//   classifier        - the classifier to train
59 void Detector::train_classifier(int level,
60 LossMachine *loss_machine,
61 ParsingPool *parsing_pool,
62 PiFeatureFamily *pi_feature_family,
63 Classifier *classifier) {
65 // Randomize the pi-feature family
// A local family of global.nb_features_for_boosting_optimization features
// is drawn; the classifier is trained against it.
67 PiFeatureFamily full_pi_feature_family;
69 full_pi_feature_family.resize(global.nb_features_for_boosting_optimization);
70 full_pi_feature_family.randomize(level);
// The sample set holds all positive cells plus
// global.nb_negative_samples_per_positive negatives per positive.
72 int nb_positives = parsing_pool->nb_positive_cells();
74 int nb_negatives_to_sample =
75 parsing_pool->nb_positive_cells() * global.nb_negative_samples_per_positive;
77 SampleSet *sample_set = new SampleSet(full_pi_feature_family.nb_features(),
78 nb_positives + nb_negatives_to_sample);
// One response slot per sample.
80 scalar_t *responses = new scalar_t[nb_positives + nb_negatives_to_sample];
// NOTE(review): the remaining arguments of this call are on lines not
// visible in this excerpt.
82 parsing_pool->weighted_sampling(loss_machine,
83 &full_pi_feature_family,
// Log the loss on the sampled set before training.
87 (*global.log_stream) << "Initial train_loss "
88 << loss_machine->loss(sample_set, responses)
// Train, then copy into pi_feature_family the features extracted from the
// full random family by the classifier.
91 classifier->train(loss_machine, sample_set, responses);
92 classifier->extract_pi_feature_family(&full_pi_feature_family, pi_feature_family);
// Builds the complete detector: the pose-cell hierarchy, then
// _nb_levels * _nb_classifiers_per_level classifiers trained level by level.
//
//   train_pool      - images used to train the classifiers
//   validation_pool - checked against the hierarchy, and parsed for ROC
//                     curves when global.write_validation_rocs is set
//   hierarchy_pool  - images from which the PoseCellHierarchy is built
//
// All thresholds are set to 0.0 here.
98 void Detector::train(LabelledImagePool *train_pool,
99 LabelledImagePool *validation_pool,
100 LabelledImagePool *hierarchy_pool) {
// NOTE(review): the condition guarding this message is not visible here.
103 cerr << "Can not re-train a Detector" << endl;
107 _hierarchy = new PoseCellHierarchy(hierarchy_pool);
// Report poses of the training set the hierarchy is incompatible with.
111 nb_violations = _hierarchy->nb_incompatible_poses(train_pool);
113 if(nb_violations > 0) {
114 cout << "The hierarchy is incompatible with the training set ("
116 << " violations)." << endl;
// Same check on the validation set.
120 nb_violations = _hierarchy->nb_incompatible_poses(validation_pool);
122 if(nb_violations > 0) {
123 cout << "The hierarchy is incompatible with the validation set ("
124 << nb_violations << " violations)."
// Size the detector: one group of classifiers per hierarchy level, stored
// flat and indexed by q = level * _nb_classifiers_per_level + c.
129 _nb_levels = _hierarchy->nb_levels();
130 _nb_classifiers_per_level = global.nb_classifiers_per_level;
131 _nb_classifiers = _nb_levels * _nb_classifiers_per_level;
132 _thresholds = new scalar_t[_nb_classifiers];
133 _classifiers = new Classifier *[_nb_classifiers];
134 _pi_feature_families = new PiFeatureFamily *[_nb_classifiers];
// Every slot gets a boosted classifier and an empty feature family.
136 for(int q = 0; q < _nb_classifiers; q++) {
137 _classifiers[q] = new BoostedClassifier(global.nb_weak_learners_per_classifier);
138 _pi_feature_families[q] = new PiFeatureFamily();
141 ParsingPool *train_parsing, *validation_parsing;
143 train_parsing = new ParsingPool(train_pool,
145 global.proportion_negative_cells_for_training);
// The validation set is parsed only when validation ROC curves are wanted.
147 if(global.write_validation_rocs) {
148 validation_parsing = new ParsingPool(validation_pool,
150 global.proportion_negative_cells_for_training);
// Otherwise the pointer is null, which is what the tests below rely on.
152 validation_parsing = 0;
155 LossMachine *loss_machine = new LossMachine(global.loss_type);
157 cout << "Building a detector." << endl;
// The progress bar counts classifiers.
159 global.bar.init(&cout, _nb_classifiers);
161 for(int l = 0; l < _nb_levels; l++) {
// Move the parsing pools down to level l of the hierarchy.
164 train_parsing->down_one_level(loss_machine, _hierarchy, l);
165 if(validation_parsing) {
166 validation_parsing->down_one_level(loss_machine, _hierarchy, l);
170 for(int c = 0; c < _nb_classifiers_per_level; c++) {
// Flat index of classifier c of level l.
171 int q = l * _nb_classifiers_per_level + c;
173 // Train the classifier
178 _pi_feature_families[q], _classifiers[q]);
180 // Update the cell responses on the training set
182 train_parsing->update_cell_responses(_pi_feature_families[q],
185 // Save the ROC curves on the training set
187 char buffer[buffer_size];
// The number in the file name is (q + 1) times the number of weak learners
// per classifier.
189 sprintf(buffer, "%s/train_%05d.roc",
191 (q + 1) * global.nb_weak_learners_per_classifier);
192 ofstream out(buffer);
193 train_parsing->write_roc(&out);
195 if(validation_parsing) {
197 // Update the cell responses on the validation set
199 validation_parsing->update_cell_responses(_pi_feature_families[q],
202 // Save the ROC curves on the validation set
204 sprintf(buffer, "%s/validation_%05d.roc",
206 (q + 1) * global.nb_weak_learners_per_classifier);
207 ofstream out(buffer);
208 validation_parsing->write_roc(&out);
// Thresholds start at 0.0; compute_thresholds() assigns _thresholds[q].
211 _thresholds[q] = 0.0;
213 global.bar.refresh(&cout, q);
217 global.bar.finish(&cout);
// validation_parsing may be null here; deleting a null pointer is a no-op.
220 delete train_parsing;
221 delete validation_parsing;
// Sets the rejection thresholds of the classifiers from the responses
// obtained on the targets of a validation pool.
//
//   validation_pool - images whose targets are scored
//   wanted_tp       - true-positive rate aimed at after the last classifier;
//                     classifier q aims at wanted_tp^((q + 1) / _nb_classifiers)
//
// Writes _thresholds[q] for the classifiers visited by the final loop.
224 void Detector::compute_thresholds(LabelledImagePool *validation_pool, scalar_t wanted_tp) {
225 LabelledImage *image;
226 int nb_targets_total = 0;
// First pass: count the targets over the whole pool to size the buffers.
228 for(int i = 0; i < validation_pool->nb_images(); i++) {
229 image = validation_pool->grab_image(i);
230 nb_targets_total += image->nb_targets();
231 validation_pool->release_image(i);
// responses[tt + nb_targets_total * q] is the cumulated response of target
// tt after classifier q.
234 scalar_t *responses = new scalar_t[_nb_classifiers * nb_targets_total];
// Second pass: follow every target down the hierarchy.
// NOTE(review): the running target index tt is declared and advanced on
// lines not visible in this excerpt.
238 for(int i = 0; i < validation_pool->nb_images(); i++) {
239 image = validation_pool->grab_image(i);
240 image->compute_rich_structure();
242 PoseCell current_cell;
244 for(int t = 0; t < image->nb_targets(); t++) {
// The response is cumulated over all the classifiers seen so far.
246 scalar_t response = 0;
248 for(int l = 0; l < _nb_levels; l++) {
250 // We get the next-level cell for that target
252 PoseCellSet cell_set;
254 cell_set.erase_content();
// Candidate cells are either the root cells or the sub-cells of the cell
// kept at the previous level.
// NOTE(review): the test selecting between the two calls is not visible here.
256 _hierarchy->add_root_cells(image, &cell_set);
258 _hierarchy->add_subcells(l, &current_cell, &cell_set);
// Exactly one candidate is expected to contain the pose of the target.
261 int nb_compliant = 0;
263 for(int c = 0; c < cell_set.nb_cells(); c++) {
264 if(cell_set.get_cell(c)->contains(image->get_target_pose(t))) {
265 current_cell = *(cell_set.get_cell(c));
270 if(nb_compliant != 1) {
271 cerr << "INCONSISTENCY (" << nb_compliant << " should be one)" << endl;
// Run the classifiers of this level on the selected cell.
275 for(int c = 0; c < _nb_classifiers_per_level; c++) {
276 int q = l * _nb_classifiers_per_level + c;
// A one-sample set holding the features of the current cell.
277 SampleSet *sample_set = new SampleSet(_pi_feature_families[q]->nb_features(), 1);
278 sample_set->set_sample(0, _pi_feature_families[q], image, &current_cell, 0);
279 response +=_classifiers[q]->response(sample_set, 0);
281 responses[tt + nb_targets_total * q] = response;
289 validation_pool->release_image(i);
// Every counted target must have been scored.
292 ASSERT(tt == nb_targets_total);
294 // Here we have in responses[] all the target responses after every
// still_detected[t] is 1 as long as no earlier threshold rejected target t.
297 int *still_detected = new int[nb_targets_total];
298 int *indexes = new int[nb_targets_total];
299 int *sorted_indexes = new int[nb_targets_total];
301 for(int t = 0; t < nb_targets_total; t++) {
302 still_detected[t] = 1;
// Number of targets rejected so far; it carries over from one classifier to
// the next.
306 int current_nb_fn = 0;
308 for(int q = 0; q < _nb_classifiers; q++) {
// Equals wanted_tp^((q + 1) / _nb_classifiers), hence wanted_tp itself at the
// last classifier.
310 scalar_t wanted_tp_at_this_classifier
311 = exp(log(wanted_tp) * scalar_t(q + 1) / scalar_t(_nb_classifiers));
// Number of false negatives tolerated once classifier q has been applied.
313 int wanted_nb_fn_at_this_classifier
314 = int(nb_targets_total * (1 - wanted_tp_at_this_classifier));
// Order the targets by their response after classifier q.
// NOTE(review): presumably by increasing response -- confirm against
// indexed_fusion_sort().
316 indexed_fusion_sort(nb_targets_total, indexes, sorted_indexes,
317 responses + q * nb_targets_total);
// Move the threshold up to the response of the next target in sorted order
// until enough targets have been rejected.
// NOTE(review): the update of current_nb_fn is on a line not visible here.
319 for(int t = 0; (current_nb_fn < wanted_nb_fn_at_this_classifier) && (t < nb_targets_total - 1); t++) {
320 int u = sorted_indexes[t];
321 int v = sorted_indexes[t+1];
322 _thresholds[q] = responses[v + nb_targets_total * q];
323 if(still_detected[u]) {
324 still_detected[u] = 0;
330 delete[] still_detected;
332 delete[] sorted_indexes;
336 //////////////////////////////////////////////////////////////////////
// Recursive step of the parsing of an image.
//
//   image            - the image being parsed
//   level            - hierarchy level of the cells to generate
//   cell             - cell reached at the previous level
//   current_response - response cumulated by that cell so far
//   result           - receives the cells that reach the last level, with
//                      their cumulated response
339 void Detector::parse_rec(RichImage *image, int level,
340 PoseCell *cell, scalar_t current_response,
341 PoseCellScoredSet *result) {
// Past the last level: the cell is recorded with its score.
343 if(level == _nb_levels) {
344 result->add_cell_with_score(cell, current_response);
348 PoseCellSet cell_set;
349 cell_set.erase_content();
// Candidate cells are either the root cells or the sub-cells of `cell`.
// NOTE(review): the test selecting between the two calls is not visible here.
352 _hierarchy->add_root_cells(image, &cell_set);
354 _hierarchy->add_subcells(level, cell, &cell_set);
// One cumulated response and one keep flag per candidate cell.
357 scalar_t *responses = new scalar_t[cell_set.nb_cells()];
358 int *keep = new int[cell_set.nb_cells()];
// Every candidate starts from the response of its parent.
360 for(int c = 0; c < cell_set.nb_cells(); c++) {
361 responses[c] = current_response;
// Apply the classifiers of this level in turn.
365 for(int a = 0; a < _nb_classifiers_per_level; a++) {
366 int q = level * _nb_classifiers_per_level + a;
// A one-sample set, refilled for each candidate cell.
367 SampleSet *samples = new SampleSet(_pi_feature_families[q]->nb_features(), 1);
368 for(int c = 0; c < cell_set.nb_cells(); c++) {
370 samples->set_sample(0, _pi_feature_families[q], image, cell_set.get_cell(c), 0);
371 responses[c] += _classifiers[q]->response(samples, 0);
// The flag records whether the cumulated response reaches threshold q.
372 keep[c] = responses[c] >= _thresholds[q];
// Recurse into the cells of this level.
// NOTE(review): presumably only for cells whose keep flag is set; the test is
// on a line not visible here.
378 for(int c = 0; c < cell_set.nb_cells(); c++) {
380 parse_rec(image, level + 1, cell_set.get_cell(c), responses[c], result);
// Parses an image: empties result_cell_set, then starts the recursive parsing
// at level 0 with a null cell and a response of zero.
388 void Detector::parse(RichImage *image, PoseCellScoredSet *result_cell_set) {
389 result_cell_set->erase_content();
390 parse_rec(image, 0, 0, 0, result_cell_set);
393 //////////////////////////////////////////////////////////////////////
// Loads a detector from a stream. The fields are read in this order: number
// of levels, number of classifiers per level, then for each classifier its
// pi-feature family, the classifier and its threshold, and finally the
// pose-cell hierarchy.
396 void Detector::read(istream *is) {
// NOTE(review): the condition guarding this message is not visible here.
398 cerr << "Can not read over an existing Detector" << endl;
402 read_var(is, &_nb_levels);
403 read_var(is, &_nb_classifiers_per_level);
// The total number of classifiers is not stored; it is recomputed.
405 _nb_classifiers = _nb_levels * _nb_classifiers_per_level;
407 _classifiers = new Classifier *[_nb_classifiers];
408 _pi_feature_families = new PiFeatureFamily *[_nb_classifiers];
409 _thresholds = new scalar_t[_nb_classifiers];
// Per-classifier data: feature family, classifier, threshold.
411 for(int q = 0; q < _nb_classifiers; q++) {
412 _pi_feature_families[q] = new PiFeatureFamily();
413 _pi_feature_families[q]->read(is);
414 _classifiers[q] = read_classifier(is);
415 read_var(is, &_thresholds[q]);
418 _hierarchy = read_hierarchy(is);
// Saves the detector to a stream: number of levels, number of classifiers per
// level, then for each classifier its pi-feature family, the classifier and
// its threshold, and finally the pose-cell hierarchy.
421 void Detector::write(ostream *os) {
422 write_var(os, &_nb_levels);
423 write_var(os, &_nb_classifiers_per_level);
// Per-classifier data: feature family, classifier, threshold.
425 for(int q = 0; q < _nb_classifiers; q++) {
426 _pi_feature_families[q]->write(os);
427 _classifiers[q]->write(os);
428 write_var(os, &_thresholds[q]);
431 _hierarchy->write(os);