automatic commit
[folded-ctf.git] / detector.cc
diff --git a/detector.cc b/detector.cc
new file mode 100644 (file)
index 0000000..1c3ef23
--- /dev/null
@@ -0,0 +1,474 @@
+
+///////////////////////////////////////////////////////////////////////////
+// This program is free software: you can redistribute it and/or modify  //
+// it under the terms of the version 3 of the GNU General Public License //
+// as published by the Free Software Foundation.                         //
+//                                                                       //
+// This program is distributed in the hope that it will be useful, but   //
+// WITHOUT ANY WARRANTY; without even the implied warranty of            //
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      //
+// General Public License for more details.                              //
+//                                                                       //
+// You should have received a copy of the GNU General Public License     //
+// along with this program. If not, see <http://www.gnu.org/licenses/>.  //
+//                                                                       //
+// Written by Francois Fleuret, (C) IDIAP                                //
+// Contact <francois.fleuret@idiap.ch> for comments & bug reports        //
+///////////////////////////////////////////////////////////////////////////
+
+#include "tools.h"
+#include "detector.h"
+#include "global.h"
+#include "classifier_reader.h"
+#include "pose_cell_hierarchy_reader.h"
+
+// Builds an empty detector: no hierarchy, no classifier, no feature
+// family and no threshold. It has to be filled with train() or read().
+Detector::Detector()
+  : _hierarchy(0),
+    _nb_levels(0),
+    _nb_classifiers_per_level(0),
+    _thresholds(0),
+    _nb_classifiers(0),
+    _classifiers(0),
+    _pi_feature_families(0) {
+}
+
+
+// Frees the thresholds, the classifiers, the feature families and the
+// hierarchy. Nothing is released when no hierarchy is attached.
+Detector::~Detector() {
+  if(!_hierarchy) return;
+
+  // The elements first, then the arrays holding them
+
+  for(int k = 0; k < _nb_classifiers; k++) {
+    delete _classifiers[k];
+    delete _pi_feature_families[k];
+  }
+
+  delete[] _thresholds;
+  delete[] _classifiers;
+  delete[] _pi_feature_families;
+  delete _hierarchy;
+}
+
+//////////////////////////////////////////////////////////////////////
+// Training
+
+void Detector::train_classifier(int level,
+                                LossMachine *loss_machine,
+                                ParsingPool *parsing_pool,
+                                PiFeatureFamily *pi_feature_family,
+                                Classifier *classifier) {
+
+  // Randomize the pi-feature family
+
+  PiFeatureFamily full_pi_feature_family;
+
+  full_pi_feature_family.resize(global.nb_features_for_boosting_optimization);
+  full_pi_feature_family.randomize(level);
+
+  int nb_positives = parsing_pool->nb_positive_cells();
+
+  int nb_negatives_to_sample =
+    parsing_pool->nb_positive_cells() * global.nb_negative_samples_per_positive;
+
+  SampleSet *sample_set = new SampleSet(full_pi_feature_family.nb_features(),
+                                        nb_positives + nb_negatives_to_sample);
+
+  scalar_t *responses = new scalar_t[nb_positives + nb_negatives_to_sample];
+
+  (*global.log_stream) << "Collecting the sampled training set." << endl;
+
+  parsing_pool->weighted_sampling(loss_machine,
+                                  &full_pi_feature_family,
+                                  sample_set,
+                                  responses);
+
+  (*global.log_stream) << "Training the classifier." << endl;
+
+  (*global.log_stream) << "Initial train_loss "
+                       << loss_machine->loss(sample_set, responses)
+                       << endl;
+
+  classifier->train(loss_machine, sample_set, responses);
+  classifier->extract_pi_feature_family(&full_pi_feature_family, pi_feature_family);
+
+  delete[] responses;
+  delete sample_set;
+}
+
+void Detector::train(LabelledImagePool *train_pool,
+                     LabelledImagePool *validation_pool,
+                     LabelledImagePool *hierarchy_pool) {
+
+  if(_hierarchy) {
+    cerr << "Can not re-train a Detector" << endl;
+    exit(1);
+  }
+
+  _hierarchy = new PoseCellHierarchy(hierarchy_pool);
+
+  int nb_violations;
+
+  nb_violations = _hierarchy->nb_incompatible_poses(train_pool);
+
+  if(nb_violations > 0) {
+    cout << "The hierarchy is incompatible with the training set ("
+         << nb_violations
+         << " violations)." << endl;
+    exit(1);
+  }
+
+  nb_violations = _hierarchy->nb_incompatible_poses(validation_pool);
+
+  if(nb_violations > 0) {
+    cout << "The hierarchy is incompatible with the validation set ("
+         << nb_violations << " violations)."
+         << endl;
+    exit(1);
+  }
+
+  _nb_levels = _hierarchy->nb_levels();
+  _nb_classifiers_per_level = global.nb_classifiers_per_level;
+  _nb_classifiers = _nb_levels * _nb_classifiers_per_level;
+  _thresholds = new scalar_t[_nb_classifiers];
+  _classifiers = new Classifier *[_nb_classifiers];
+  _pi_feature_families = new PiFeatureFamily *[_nb_classifiers];
+
+  for(int q = 0; q < _nb_classifiers; q++) {
+    _classifiers[q] = new BoostedClassifier(global.nb_weak_learners_per_classifier);
+    _pi_feature_families[q] = new PiFeatureFamily();
+  }
+
+  ParsingPool *train_parsing, *validation_parsing;
+
+  train_parsing = new ParsingPool(train_pool,
+                                  _hierarchy,
+                                  global.proportion_negative_cells_for_training);
+
+  if(global.write_validation_rocs) {
+    validation_parsing = new ParsingPool(validation_pool,
+                                         _hierarchy,
+                                         global.proportion_negative_cells_for_training);
+  } else {
+    validation_parsing = 0;
+  }
+
+  LossMachine *loss_machine = new LossMachine(global.loss_type);
+
+  cout << "Building a detector." << endl;
+
+  global.bar.init(&cout, _nb_classifiers);
+
+  for(int l = 0; l < _nb_levels; l++) {
+
+    if(l > 0) {
+      train_parsing->down_one_level(loss_machine, _hierarchy, l);
+      if(validation_parsing) {
+        validation_parsing->down_one_level(loss_machine, _hierarchy, l);
+      }
+    }
+
+    for(int c = 0; c < _nb_classifiers_per_level; c++) {
+      int q = l * _nb_classifiers_per_level + c;
+
+      (*global.log_stream) << "Building classifier " << q << " (level " << l << ")" << endl;
+
+      // Train the classifier
+
+      train_classifier(l,
+                       loss_machine,
+                       train_parsing,
+                       _pi_feature_families[q], _classifiers[q]);
+
+      // Update the cell responses on the training set
+
+      (*global.log_stream) << "Updating training cell responses." << endl;
+
+      train_parsing->update_cell_responses(_pi_feature_families[q],
+                                           _classifiers[q]);
+
+      // Save the ROC curves on the training set
+
+      char buffer[buffer_size];
+
+      sprintf(buffer, "%s/train_%05d.roc",
+              global.result_path,
+              (q + 1) * global.nb_weak_learners_per_classifier);
+      ofstream out(buffer);
+      train_parsing->write_roc(&out);
+
+      if(validation_parsing) {
+
+        // Update the cell responses on the validation set
+
+        (*global.log_stream) << "Updating validation cell responses." << endl;
+
+        validation_parsing->update_cell_responses(_pi_feature_families[q],
+                                                  _classifiers[q]);
+
+        // Save the ROC curves on the validation set
+
+        sprintf(buffer, "%s/validation_%05d.roc",
+                global.result_path,
+                (q + 1) * global.nb_weak_learners_per_classifier);
+        ofstream out(buffer);
+        validation_parsing->write_roc(&out);
+      }
+
+      _thresholds[q] = 0.0;
+
+      global.bar.refresh(&cout, q);
+    }
+  }
+
+  global.bar.finish(&cout);
+
+  delete loss_machine;
+  delete train_parsing;
+  delete validation_parsing;
+}
+
+void Detector::compute_thresholds(LabelledImagePool *validation_pool, scalar_t wanted_tp) {
+  LabelledImage *image;
+  int nb_targets_total = 0;
+
+  for(int i = 0; i < validation_pool->nb_images(); i++) {
+    image = validation_pool->grab_image(i);
+    nb_targets_total += image->nb_targets();
+    validation_pool->release_image(i);
+  }
+
+  scalar_t *responses = new scalar_t[_nb_classifiers * nb_targets_total];
+
+  int tt = 0;
+
+  for(int i = 0; i < validation_pool->nb_images(); i++) {
+    image = validation_pool->grab_image(i);
+    image->compute_rich_structure();
+
+    PoseCell current_cell;
+
+    for(int t = 0; t < image->nb_targets(); t++) {
+
+      scalar_t response = 0;
+
+      for(int l = 0; l < _nb_levels; l++) {
+
+        // We get the next-level cell for that target
+
+        PoseCellSet cell_set;
+
+        cell_set.erase_content();
+        if(l == 0) {
+          _hierarchy->add_root_cells(image, &cell_set);
+        } else {
+          _hierarchy->add_subcells(l, &current_cell, &cell_set);
+        }
+
+        int nb_compliant = 0;
+
+        for(int c = 0; c < cell_set.nb_cells(); c++) {
+          if(cell_set.get_cell(c)->contains(image->get_target_pose(t))) {
+            current_cell = *(cell_set.get_cell(c));
+            nb_compliant++;
+          }
+        }
+
+        if(nb_compliant != 1) {
+          cerr << "INCONSISTENCY (" << nb_compliant << " should be one)" << endl;
+          abort();
+        }
+
+        for(int c = 0; c < _nb_classifiers_per_level; c++) {
+          int q = l * _nb_classifiers_per_level + c;
+          SampleSet *sample_set = new SampleSet(_pi_feature_families[q]->nb_features(), 1);
+          sample_set->set_sample(0, _pi_feature_families[q], image, &current_cell, 0);
+          response +=_classifiers[q]->response(sample_set, 0);
+          delete sample_set;
+          responses[tt + nb_targets_total * q] = response;
+        }
+
+      }
+
+      tt++;
+    }
+
+    validation_pool->release_image(i);
+  }
+
+  ASSERT(tt == nb_targets_total);
+
+  // Here we have in responses[] all the target responses after every
+  // classifier
+
+  int *still_detected = new int[nb_targets_total];
+  int *indexes = new int[nb_targets_total];
+  int *sorted_indexes = new int[nb_targets_total];
+
+  for(int t = 0; t < nb_targets_total; t++) {
+    still_detected[t] = 1;
+    indexes[t] = t;
+  }
+
+  int current_nb_fn = 0;
+
+  for(int q = 0; q < _nb_classifiers; q++) {
+
+    scalar_t wanted_tp_at_this_classifier
+      = exp(log(wanted_tp) * scalar_t(q + 1) / scalar_t(_nb_classifiers));
+
+    int wanted_nb_fn_at_this_classifier
+      = int(nb_targets_total * (1 - wanted_tp_at_this_classifier));
+
+    (*global.log_stream) << "q = " << q
+                         << " wanted_tp_at_this_classifier = " << wanted_tp_at_this_classifier
+                         << " wanted_nb_fn_at_this_classifier = " << wanted_nb_fn_at_this_classifier
+                         << endl;
+
+    indexed_fusion_sort(nb_targets_total, indexes, sorted_indexes,
+                        responses + q * nb_targets_total);
+
+    for(int t = 0; (current_nb_fn < wanted_nb_fn_at_this_classifier) && (t < nb_targets_total - 1); t++) {
+      int u = sorted_indexes[t];
+      int v = sorted_indexes[t+1];
+      _thresholds[q] = responses[v + nb_targets_total * q];
+      if(still_detected[u]) {
+        still_detected[u] = 0;
+        current_nb_fn++;
+      }
+    }
+  }
+
+  delete[] still_detected;
+  delete[] indexes;
+  delete[] sorted_indexes;
+
+  { ////////////////////////////////////////////////////////////////////
+    // Sanity check
+
+    int nb_positives = 0;
+
+    for(int t = 0; t < nb_targets_total; t++) {
+      int positive = 1;
+      for(int q = 0; q < _nb_classifiers; q++) {
+        if(responses[t + nb_targets_total * q] < _thresholds[q]) positive = 0;
+      }
+      if(positive) nb_positives++;
+    }
+
+    scalar_t actual_tp = scalar_t(nb_positives) / scalar_t(nb_targets_total);
+
+    (*global.log_stream) << "Overall detection rate " << nb_positives << "/" << nb_targets_total
+                         << " "
+                         << "actual_tp = " << actual_tp
+                         << " "
+                         << "wanted_tp = " << wanted_tp
+                         << endl;
+
+    if(actual_tp < wanted_tp) {
+      cerr << "INCONSISTENCY" << endl;
+      abort();
+    }
+  } ////////////////////////////////////////////////////////////////////
+
+  delete[] responses;
+}
+
+//////////////////////////////////////////////////////////////////////
+// Parsing
+
+void Detector::parse_rec(RichImage *image, int level,
+                         PoseCell *cell, scalar_t current_response,
+                         PoseCellScoredSet *result) {
+
+  if(level == _nb_levels) {
+    result->add_cell_with_score(cell, current_response);
+    return;
+  }
+
+  PoseCellSet cell_set;
+  cell_set.erase_content();
+
+  if(level == 0) {
+    _hierarchy->add_root_cells(image, &cell_set);
+  } else {
+    _hierarchy->add_subcells(level, cell, &cell_set);
+  }
+
+  scalar_t *responses = new scalar_t[cell_set.nb_cells()];
+  int *keep = new int[cell_set.nb_cells()];
+
+  for(int c = 0; c < cell_set.nb_cells(); c++) {
+    responses[c] = current_response;
+    keep[c] = 1;
+  }
+
+  for(int a = 0; a < _nb_classifiers_per_level; a++) {
+    int q = level * _nb_classifiers_per_level + a;
+    SampleSet *samples = new SampleSet(_pi_feature_families[q]->nb_features(), 1);
+    for(int c = 0; c < cell_set.nb_cells(); c++) {
+      if(keep[c]) {
+        samples->set_sample(0, _pi_feature_families[q], image, cell_set.get_cell(c), 0);
+        responses[c] += _classifiers[q]->response(samples, 0);
+        keep[c] = responses[c] >= _thresholds[q];
+      }
+    }
+    delete samples;
+  }
+
+  for(int c = 0; c < cell_set.nb_cells(); c++) {
+    if(keep[c]) {
+      parse_rec(image, level + 1, cell_set.get_cell(c), responses[c], result);
+    }
+  }
+
+  delete[] keep;
+  delete[] responses;
+}
+
+// Parses an image: result_cell_set is emptied, then filled by the
+// recursive parsing started at level 0, with no parent cell and a null
+// response.
+void Detector::parse(RichImage *image, PoseCellScoredSet *result_cell_set) {
+  PoseCell *no_parent_cell = 0;
+  scalar_t initial_response = 0;
+
+  result_cell_set->erase_content();
+  parse_rec(image, 0, no_parent_cell, initial_response, result_cell_set);
+}
+
+//////////////////////////////////////////////////////////////////////
+// Storage
+
+void Detector::read(istream *is) {
+  if(_hierarchy) {
+    cerr << "Can not read over an existing Detector" << endl;
+    exit(1);
+  }
+
+  read_var(is, &_nb_levels);
+  read_var(is, &_nb_classifiers_per_level);
+
+  _nb_classifiers = _nb_levels * _nb_classifiers_per_level;
+
+  _classifiers = new Classifier *[_nb_classifiers];
+  _pi_feature_families = new PiFeatureFamily *[_nb_classifiers];
+  _thresholds = new scalar_t[_nb_classifiers];
+
+  for(int q = 0; q < _nb_classifiers; q++) {
+    cout << "Read classifier " << q << endl;
+    _pi_feature_families[q] = new PiFeatureFamily();
+    _pi_feature_families[q]->read(is);
+    _classifiers[q] = read_classifier(is);
+    read_var(is, &_thresholds[q]);
+  }
+
+  _hierarchy = read_hierarchy(is);
+
+  (*global.log_stream) << "Read Detector" << endl
+                       << "  _nb_levels " << _nb_levels << endl
+                       << "  _nb_classifiers_per_level " << _nb_classifiers_per_level << endl;
+}
+
+// Writes the detector into a stream: the number of levels, the number
+// of classifiers per level, then for each classifier its feature
+// family, the classifier itself and its threshold, and finally the
+// hierarchy.
+//
+// The program stops if the Detector has no hierarchy, which is the
+// case when it was neither trained nor read.
+void Detector::write(ostream *os) {
+  if(!_hierarchy) {
+    cerr << "Can not write an empty Detector" << endl;
+    exit(1);
+  }
+
+  write_var(os, &_nb_levels);
+  write_var(os, &_nb_classifiers_per_level);
+
+  for(int q = 0; q < _nb_classifiers; q++) {
+    _pi_feature_families[q]->write(os);
+    _classifiers[q]->write(os);
+    write_var(os, &_thresholds[q]);
+  }
+
+  _hierarchy->write(os);
+}