--- /dev/null
+
+///////////////////////////////////////////////////////////////////////////
+// This program is free software: you can redistribute it and/or modify //
+// it under the terms of the version 3 of the GNU General Public License //
+// as published by the Free Software Foundation. //
+// //
+// This program is distributed in the hope that it will be useful, but //
+// WITHOUT ANY WARRANTY; without even the implied warranty of //
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU //
+// General Public License for more details. //
+// //
+// You should have received a copy of the GNU General Public License //
+// along with this program. If not, see <http://www.gnu.org/licenses/>. //
+// //
+// Written by Francois Fleuret, (C) IDIAP //
+// Contact <francois.fleuret@idiap.ch> for comments & bug reports //
+///////////////////////////////////////////////////////////////////////////
+
+#include "tools.h"
+#include "detector.h"
+#include "global.h"
+#include "classifier_reader.h"
+#include "pose_cell_hierarchy_reader.h"
+
+// Creates an empty detector: every counter is zero and every owned
+// pointer is null. train() or read() must be called before the object
+// holds a hierarchy and classifiers.
+Detector::Detector() {
+  // Sizes
+  _nb_levels = 0;
+  _nb_classifiers_per_level = 0;
+  _nb_classifiers = 0;
+
+  // Owned storage, nothing allocated yet
+  _hierarchy = 0;
+  _classifiers = 0;
+  _pi_feature_families = 0;
+  _thresholds = 0;
+}
+
+
+// Releases everything the detector owns. The presence of a hierarchy
+// is used as the marker that storage was allocated; without it nothing
+// is freed.
+Detector::~Detector() {
+  if(!_hierarchy) return;
+
+  delete[] _thresholds;
+
+  // Each slot owns one classifier and its matching pi-feature family
+  int k = 0;
+  while(k < _nb_classifiers) {
+    delete _classifiers[k];
+    delete _pi_feature_families[k];
+    k++;
+  }
+
+  delete[] _classifiers;
+  delete[] _pi_feature_families;
+  delete _hierarchy;
+}
+
+//////////////////////////////////////////////////////////////////////
+// Training
+
+// Trains a single classifier for the given hierarchy level.
+//
+// A random pi-feature family of
+// global.nb_features_for_boosting_optimization features is drawn, a
+// training set is sampled from parsing_pool through
+// weighted_sampling(), and the classifier is trained on it. The
+// features retained by the classifier are then extracted from the
+// random family into pi_feature_family.
+//
+//   level              level passed to the family randomization
+//   loss_machine       loss used for sampling, logging and training
+//   parsing_pool       source of the sampled cells
+//   pi_feature_family  output: receives the extracted features
+//   classifier         output: the classifier that gets trained
+void Detector::train_classifier(int level,
+                                LossMachine *loss_machine,
+                                ParsingPool *parsing_pool,
+                                PiFeatureFamily *pi_feature_family,
+                                Classifier *classifier) {
+
+  // Random candidate features among which the training picks
+
+  PiFeatureFamily candidate_features;
+
+  candidate_features.resize(global.nb_features_for_boosting_optimization);
+  candidate_features.randomize(level);
+
+  // Size of the sampled set: all the positives, plus a fixed number of
+  // negatives per positive
+
+  int positive_count = parsing_pool->nb_positive_cells();
+
+  int negative_count =
+    parsing_pool->nb_positive_cells() * global.nb_negative_samples_per_positive;
+
+  SampleSet *samples = new SampleSet(candidate_features.nb_features(),
+                                     positive_count + negative_count);
+
+  scalar_t *sample_responses = new scalar_t[positive_count + negative_count];
+
+  (*global.log_stream) << "Collecting the sampled training set." << endl;
+
+  parsing_pool->weighted_sampling(loss_machine,
+                                  &candidate_features,
+                                  samples,
+                                  sample_responses);
+
+  (*global.log_stream) << "Training the classifier." << endl;
+
+  (*global.log_stream) << "Initial train_loss "
+                       << loss_machine->loss(samples, sample_responses)
+                       << endl;
+
+  classifier->train(loss_machine, samples, sample_responses);
+  classifier->extract_pi_feature_family(&candidate_features, pi_feature_family);
+
+  // The sampled set is only needed during the training
+
+  delete[] sample_responses;
+  delete samples;
+}
+
+// Builds the complete detector.
+//
+// The pose-cell hierarchy is built from hierarchy_pool and checked
+// against train_pool and validation_pool; the program exits with
+// status 1 if either pool contains poses incompatible with it, or if
+// the detector already has a hierarchy.
+//
+// _nb_levels * global.nb_classifiers_per_level classifiers are then
+// trained, level by level. After each classifier the cell responses of
+// the training parsing pool are updated and its ROC curve is written
+// to "<global.result_path>/train_NNNNN.roc". The same is done for the
+// validation pool ("validation_NNNNN.roc"), but only when
+// global.write_validation_rocs is set.
+//
+// Every threshold is left at 0.0; compute_thresholds() is the function
+// that sets them to meaningful values.
+void Detector::train(LabelledImagePool *train_pool,
+                     LabelledImagePool *validation_pool,
+                     LabelledImagePool *hierarchy_pool) {
+
+  if(_hierarchy) {
+    cerr << "Can not re-train a Detector" << endl;
+    exit(1);
+  }
+
+  _hierarchy = new PoseCellHierarchy(hierarchy_pool);
+
+  int nb_violations;
+
+  nb_violations = _hierarchy->nb_incompatible_poses(train_pool);
+
+  if(nb_violations > 0) {
+    cout << "The hierarchy is incompatible with the training set ("
+         << nb_violations
+         << " violations)." << endl;
+    exit(1);
+  }
+
+  nb_violations = _hierarchy->nb_incompatible_poses(validation_pool);
+
+  if(nb_violations > 0) {
+    cout << "The hierarchy is incompatible with the validation set ("
+         << nb_violations << " violations)."
+         << endl;
+    exit(1);
+  }
+
+  // Allocate one classifier, one pi-feature family and one threshold
+  // per (level, classifier-in-level) pair
+
+  _nb_levels = _hierarchy->nb_levels();
+  _nb_classifiers_per_level = global.nb_classifiers_per_level;
+  _nb_classifiers = _nb_levels * _nb_classifiers_per_level;
+  _thresholds = new scalar_t[_nb_classifiers];
+  _classifiers = new Classifier *[_nb_classifiers];
+  _pi_feature_families = new PiFeatureFamily *[_nb_classifiers];
+
+  for(int q = 0; q < _nb_classifiers; q++) {
+    _classifiers[q] = new BoostedClassifier(global.nb_weak_learners_per_classifier);
+    _pi_feature_families[q] = new PiFeatureFamily();
+  }
+
+  ParsingPool *train_parsing, *validation_parsing;
+
+  train_parsing = new ParsingPool(train_pool,
+                                  _hierarchy,
+                                  global.proportion_negative_cells_for_training);
+
+  // The validation parsing pool only serves to write ROC curves, so it
+  // is not built when they are not requested
+
+  if(global.write_validation_rocs) {
+    validation_parsing = new ParsingPool(validation_pool,
+                                         _hierarchy,
+                                         global.proportion_negative_cells_for_training);
+  } else {
+    validation_parsing = 0;
+  }
+
+  LossMachine *loss_machine = new LossMachine(global.loss_type);
+
+  cout << "Building a detector." << endl;
+
+  global.bar.init(&cout, _nb_classifiers);
+
+  for(int l = 0; l < _nb_levels; l++) {
+
+    // Every level but the first starts by moving the parsing pools
+    // one level down in the hierarchy
+
+    if(l > 0) {
+      train_parsing->down_one_level(loss_machine, _hierarchy, l);
+      if(validation_parsing) {
+        validation_parsing->down_one_level(loss_machine, _hierarchy, l);
+      }
+    }
+
+    for(int c = 0; c < _nb_classifiers_per_level; c++) {
+      int q = l * _nb_classifiers_per_level + c;
+
+      (*global.log_stream) << "Building classifier " << q << " (level " << l << ")" << endl;
+
+      // Train the classifier
+
+      train_classifier(l,
+                       loss_machine,
+                       train_parsing,
+                       _pi_feature_families[q], _classifiers[q]);
+
+      // Update the cell responses on the training set
+
+      (*global.log_stream) << "Updating training cell responses." << endl;
+
+      train_parsing->update_cell_responses(_pi_feature_families[q],
+                                           _classifiers[q]);
+
+      // Save the ROC curves on the training set
+
+      char buffer[buffer_size];
+
+      // snprintf is bounded by the buffer size, so a long
+      // global.result_path can not write past the end of buffer
+
+      snprintf(buffer, sizeof(buffer), "%s/train_%05d.roc",
+               global.result_path,
+               (q + 1) * global.nb_weak_learners_per_classifier);
+      ofstream train_roc(buffer);
+      train_parsing->write_roc(&train_roc);
+
+      if(validation_parsing) {
+
+        // Update the cell responses on the validation set
+
+        (*global.log_stream) << "Updating validation cell responses." << endl;
+
+        validation_parsing->update_cell_responses(_pi_feature_families[q],
+                                                  _classifiers[q]);
+
+        // Save the ROC curves on the validation set
+
+        snprintf(buffer, sizeof(buffer), "%s/validation_%05d.roc",
+                 global.result_path,
+                 (q + 1) * global.nb_weak_learners_per_classifier);
+        ofstream validation_roc(buffer);
+        validation_parsing->write_roc(&validation_roc);
+      }
+
+      _thresholds[q] = 0.0;
+
+      global.bar.refresh(&cout, q);
+    }
+  }
+
+  global.bar.finish(&cout);
+
+  delete loss_machine;
+  delete train_parsing;
+  delete validation_parsing;
+}
+
+// Sets _thresholds[] from the responses of the targets of
+// validation_pool, so that the overall proportion of targets passing
+// all the thresholds is at least wanted_tp.
+//
+// For every target, the cell containing its pose is followed down the
+// hierarchy, and the response accumulated after each classifier q is
+// stored in responses[t + nb_targets_total * q]. Classifier q is then
+// allowed a cumulated true-positive rate of
+// wanted_tp^((q + 1) / _nb_classifiers), and its threshold is raised
+// over the sorted responses until the cumulated number of rejected
+// targets reaches the corresponding count.
+//
+// The function aborts if, at some level, a target pose is not
+// contained in exactly one cell, or if the final detection rate is
+// below wanted_tp.
+void Detector::compute_thresholds(LabelledImagePool *validation_pool, scalar_t wanted_tp) {
+  LabelledImage *image;
+  int nb_targets_total = 0;
+
+  // First pass to count the targets and size the response table
+
+  for(int i = 0; i < validation_pool->nb_images(); i++) {
+    image = validation_pool->grab_image(i);
+    nb_targets_total += image->nb_targets();
+    validation_pool->release_image(i);
+  }
+
+  scalar_t *responses = new scalar_t[_nb_classifiers * nb_targets_total];
+
+  // Index of the current target among all the targets of the pool
+
+  int tt = 0;
+
+  for(int i = 0; i < validation_pool->nb_images(); i++) {
+    image = validation_pool->grab_image(i);
+    image->compute_rich_structure();
+
+    PoseCell current_cell;
+
+    for(int t = 0; t < image->nb_targets(); t++) {
+
+      // Accumulated over all the classifiers of all the levels
+
+      scalar_t response = 0;
+
+      for(int l = 0; l < _nb_levels; l++) {
+
+        // We get the next-level cell for that target
+
+        PoseCellSet cell_set;
+
+        cell_set.erase_content();
+        if(l == 0) {
+          _hierarchy->add_root_cells(image, &cell_set);
+        } else {
+          _hierarchy->add_subcells(l, &current_cell, &cell_set);
+        }
+
+        int nb_compliant = 0;
+
+        for(int c = 0; c < cell_set.nb_cells(); c++) {
+          if(cell_set.get_cell(c)->contains(image->get_target_pose(t))) {
+            current_cell = *(cell_set.get_cell(c));
+            nb_compliant++;
+          }
+        }
+
+        // The target pose must be in one and only one cell
+
+        if(nb_compliant != 1) {
+          cerr << "INCONSISTENCY (" << nb_compliant << " should be one)" << endl;
+          abort();
+        }
+
+        for(int c = 0; c < _nb_classifiers_per_level; c++) {
+          int q = l * _nb_classifiers_per_level + c;
+          SampleSet *sample_set = new SampleSet(_pi_feature_families[q]->nb_features(), 1);
+          sample_set->set_sample(0, _pi_feature_families[q], image, &current_cell, 0);
+          response +=_classifiers[q]->response(sample_set, 0);
+          delete sample_set;
+          responses[tt + nb_targets_total * q] = response;
+        }
+
+      }
+
+      tt++;
+    }
+
+    validation_pool->release_image(i);
+  }
+
+  ASSERT(tt == nb_targets_total);
+
+  // Here we have in responses[] all the target responses after every
+  // classifier
+
+  int *still_detected = new int[nb_targets_total];
+  int *indexes = new int[nb_targets_total];
+  int *sorted_indexes = new int[nb_targets_total];
+
+  for(int t = 0; t < nb_targets_total; t++) {
+    still_detected[t] = 1;
+    indexes[t] = t;
+  }
+
+  // Number of targets rejected so far, cumulated over the classifiers
+
+  int current_nb_fn = 0;
+
+  for(int q = 0; q < _nb_classifiers; q++) {
+
+    // wanted_tp^((q + 1) / _nb_classifiers), which is equal to
+    // wanted_tp at the last classifier
+
+    scalar_t wanted_tp_at_this_classifier
+      = exp(log(wanted_tp) * scalar_t(q + 1) / scalar_t(_nb_classifiers));
+
+    int wanted_nb_fn_at_this_classifier
+      = int(nb_targets_total * (1 - wanted_tp_at_this_classifier));
+
+    (*global.log_stream) << "q = " << q
+                         << " wanted_tp_at_this_classifier = " << wanted_tp_at_this_classifier
+                         << " wanted_nb_fn_at_this_classifier = " << wanted_nb_fn_at_this_classifier
+                         << endl;
+
+    indexed_fusion_sort(nb_targets_total, indexes, sorted_indexes,
+                        responses + q * nb_targets_total);
+
+    // Go through the targets in the order given by sorted_indexes,
+    // moving the threshold to the response of the next target, and
+    // counting a new false negative for every target which was still
+    // detected.
+    //
+    // NOTE(review): _thresholds[q] is only written inside this loop,
+    // so it keeps its previous value when the loop body does not run
+    // for this classifier -- confirm this is intended.
+
+    for(int t = 0; (current_nb_fn < wanted_nb_fn_at_this_classifier) && (t < nb_targets_total - 1); t++) {
+      int u = sorted_indexes[t];
+      int v = sorted_indexes[t+1];
+      _thresholds[q] = responses[v + nb_targets_total * q];
+      if(still_detected[u]) {
+        still_detected[u] = 0;
+        current_nb_fn++;
+      }
+    }
+  }
+
+  delete[] still_detected;
+  delete[] indexes;
+  delete[] sorted_indexes;
+
+  { ////////////////////////////////////////////////////////////////////
+    // Sanity check
+
+    // A target is a positive if its response is not below the
+    // threshold after any of the classifiers
+
+    int nb_positives = 0;
+
+    for(int t = 0; t < nb_targets_total; t++) {
+      int positive = 1;
+      for(int q = 0; q < _nb_classifiers; q++) {
+        if(responses[t + nb_targets_total * q] < _thresholds[q]) positive = 0;
+      }
+      if(positive) nb_positives++;
+    }
+
+    scalar_t actual_tp = scalar_t(nb_positives) / scalar_t(nb_targets_total);
+
+    (*global.log_stream) << "Overall detection rate " << nb_positives << "/" << nb_targets_total
+                         << " "
+                         << "actual_tp = " << actual_tp
+                         << " "
+                         << "wanted_tp = " << wanted_tp
+                         << endl;
+
+    if(actual_tp < wanted_tp) {
+      cerr << "INCONSISTENCY" << endl;
+      abort();
+    }
+  } ////////////////////////////////////////////////////////////////////
+
+  delete[] responses;
+}
+
+//////////////////////////////////////////////////////////////////////
+// Parsing
+
+// Recursive step of the parsing.
+//
+// The candidate cells of the given level are the root cells of the
+// image at level 0, and the sub-cells of cell otherwise. Each one goes
+// through the classifiers of the level in order, its response adding
+// up to current_response, and is dropped as soon as that response
+// falls below the threshold of a classifier. The cells that survive
+// are parsed at the next level. A call with level == _nb_levels stores
+// cell and current_response in result.
+void Detector::parse_rec(RichImage *image, int level,
+                         PoseCell *cell, scalar_t current_response,
+                         PoseCellScoredSet *result) {
+
+  // Below the last level, the cell is a detection
+
+  if(level == _nb_levels) {
+    result->add_cell_with_score(cell, current_response);
+    return;
+  }
+
+  PoseCellSet candidates;
+  candidates.erase_content();
+
+  if(level != 0) {
+    _hierarchy->add_subcells(level, cell, &candidates);
+  } else {
+    _hierarchy->add_root_cells(image, &candidates);
+  }
+
+  scalar_t *scores = new scalar_t[candidates.nb_cells()];
+  int *alive = new int[candidates.nb_cells()];
+
+  // Every candidate starts alive, with the response of its parent
+
+  for(int i = 0; i < candidates.nb_cells(); i++) {
+    scores[i] = current_response;
+    alive[i] = 1;
+  }
+
+  const int first_classifier = level * _nb_classifiers_per_level;
+
+  for(int q = first_classifier; q < first_classifier + _nb_classifiers_per_level; q++) {
+    // A single-sample set, re-used for all the cells
+    SampleSet *one_sample = new SampleSet(_pi_feature_families[q]->nb_features(), 1);
+    for(int i = 0; i < candidates.nb_cells(); i++) {
+      if(!alive[i]) continue;
+      one_sample->set_sample(0, _pi_feature_families[q], image, candidates.get_cell(i), 0);
+      scores[i] += _classifiers[q]->response(one_sample, 0);
+      alive[i] = (scores[i] >= _thresholds[q]) ? 1 : 0;
+    }
+    delete one_sample;
+  }
+
+  // Go down into the cells that passed all the classifiers
+
+  for(int i = 0; i < candidates.nb_cells(); i++) {
+    if(!alive[i]) continue;
+    parse_rec(image, level + 1, candidates.get_cell(i), scores[i], result);
+  }
+
+  delete[] alive;
+  delete[] scores;
+}
+
+// Parses an image: result_cell_set is emptied, then filled by the
+// recursive parsing started at level 0.
+void Detector::parse(RichImage *image, PoseCellScoredSet *result_cell_set) {
+  result_cell_set->erase_content();
+
+  // The top-level call has no parent cell and a null response
+  PoseCell *const no_parent_cell = 0;
+  const scalar_t initial_response = 0;
+
+  parse_rec(image, 0, no_parent_cell, initial_response, result_cell_set);
+}
+
+//////////////////////////////////////////////////////////////////////
+// Storage
+
+// Loads a detector from a stream, in the layout produced by write():
+// the number of levels, the number of classifiers per level, then for
+// every classifier its pi-feature family, the classifier itself and
+// its threshold, and finally the pose-cell hierarchy.
+//
+// The program exits with status 1 if the detector already has a
+// hierarchy, or if one of the two counts read from the stream is
+// negative.
+void Detector::read(istream *is) {
+  if(_hierarchy) {
+    cerr << "Can not read over an existing Detector" << endl;
+    exit(1);
+  }
+
+  read_var(is, &_nb_levels);
+  read_var(is, &_nb_classifiers_per_level);
+
+  // The counts come straight from the stream and are used as array
+  // sizes below, so a corrupted value must be rejected before any
+  // allocation is attempted
+
+  if(_nb_levels < 0 || _nb_classifiers_per_level < 0) {
+    cerr << "Invalid Detector (negative number of levels or of classifiers per level)" << endl;
+    exit(1);
+  }
+
+  _nb_classifiers = _nb_levels * _nb_classifiers_per_level;
+
+  _classifiers = new Classifier *[_nb_classifiers];
+  _pi_feature_families = new PiFeatureFamily *[_nb_classifiers];
+  _thresholds = new scalar_t[_nb_classifiers];
+
+  for(int q = 0; q < _nb_classifiers; q++) {
+    cout << "Read classifier " << q << endl;
+    _pi_feature_families[q] = new PiFeatureFamily();
+    _pi_feature_families[q]->read(is);
+    _classifiers[q] = read_classifier(is);
+    read_var(is, &_thresholds[q]);
+  }
+
+  // The hierarchy is read last, which is also what marks the detector
+  // as initialized for the destructor and for the test above
+
+  _hierarchy = read_hierarchy(is);
+
+  (*global.log_stream) << "Read Detector" << endl
+                       << " _nb_levels " << _nb_levels << endl
+                       << " _nb_classifiers_per_level " << _nb_classifiers_per_level << endl;
+}
+
+// Saves the detector to a stream: the number of levels, the number of
+// classifiers per level, then for every classifier its pi-feature
+// family, the classifier itself and its threshold, and finally the
+// pose-cell hierarchy.
+void Detector::write(ostream *os) {
+  // Header
+  write_var(os, &_nb_levels);
+  write_var(os, &_nb_classifiers_per_level);
+
+  // One record per classifier
+  int k = 0;
+  while(k < _nb_classifiers) {
+    _pi_feature_families[k]->write(os);
+    _classifiers[k]->write(os);
+    write_var(os, &_thresholds[k]);
+    k++;
+  }
+
+  // Hierarchy last
+  _hierarchy->write(os);
+}