automatic commit
[folded-ctf.git] / boosted_classifier.cc
diff --git a/boosted_classifier.cc b/boosted_classifier.cc
new file mode 100644 (file)
index 0000000..8080ad2
--- /dev/null
@@ -0,0 +1,130 @@
+
+///////////////////////////////////////////////////////////////////////////
+// This program is free software: you can redistribute it and/or modify  //
+// it under the terms of the version 3 of the GNU General Public License //
+// as published by the Free Software Foundation.                         //
+//                                                                       //
+// This program is distributed in the hope that it will be useful, but   //
+// WITHOUT ANY WARRANTY; without even the implied warranty of            //
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      //
+// General Public License for more details.                              //
+//                                                                       //
+// You should have received a copy of the GNU General Public License     //
+// along with this program. If not, see <http://www.gnu.org/licenses/>.  //
+//                                                                       //
+// Written by Francois Fleuret, (C) IDIAP                                //
+// Contact <francois.fleuret@idiap.ch> for comments & bug reports        //
+///////////////////////////////////////////////////////////////////////////
+
+#include "classifier_reader.h"
+#include "fusion_sort.h"
+
+#include "boosted_classifier.h"
+#include "tools.h"
+
+// Builds an untrained classifier meant to hold nb_weak_learners weak
+// learners. No weak learner is allocated here: the array pointer is
+// left null, and the loss type is copied from global.loss_type.
+BoostedClassifier::BoostedClassifier(int nb_weak_learners) {
+  // A null array is what marks the classifier as not yet built
+  _weak_learners = 0;
+  _nb_weak_learners = nb_weak_learners;
+  _loss_type = global.loss_type;
+}
+
+// Builds an empty classifier: zero weak learners, a null array of
+// weak learners, and the loss type copied from global.loss_type.
+BoostedClassifier::BoostedClassifier() {
+  // A null array is what marks the classifier as not yet built
+  _weak_learners = 0;
+  _nb_weak_learners = 0;
+  _loss_type = global.loss_type;
+}
+
+// Frees every weak learner, then the array that holds them. Nothing
+// is done if the array was never allocated.
+BoostedClassifier::~BoostedClassifier() {
+  if(!_weak_learners) return;
+
+  for(int k = 0; k < _nb_weak_learners; ++k) {
+    delete _weak_learners[k];
+  }
+
+  delete[] _weak_learners;
+}
+
+// Returns the sum of the responses of all the weak learners on sample
+// n_sample of sample_set. The running sum is checked against NaN
+// right after each weak learner has been added.
+scalar_t BoostedClassifier::response(SampleSet *sample_set, int n_sample) {
+  scalar_t r = 0;
+  int k = 0;
+
+  while(k < _nb_weak_learners) {
+    DecisionTree *tree = _weak_learners[k];
+    r += tree->response(sample_set, n_sample);
+    ASSERT(!isnan(r));
+    ++k;
+  }
+
+  return r;
+}
+
+void BoostedClassifier::train(LossMachine *loss_machine,
+                              SampleSet *sample_set, scalar_t *train_responses) {
+
+  if(_weak_learners) {
+    cerr << "Can not re-train a BoostedClassifier" << endl;
+    exit(1);
+  }
+
+  int nb_pos = 0, nb_neg = 0;
+
+  for(int s = 0; s < sample_set->nb_samples(); s++) {
+    if(sample_set->label(s) > 0) nb_pos++;
+    else if(sample_set->label(s) < 0) nb_neg++;
+  }
+
+  _weak_learners = new DecisionTree *[_nb_weak_learners];
+
+  (*global.log_stream) << "With " << nb_pos << " positive and " << nb_neg << " negative samples." << endl;
+
+  for(int w = 0; w  < _nb_weak_learners; w++) {
+
+    _weak_learners[w] = new DecisionTree();
+    _weak_learners[w]->train(loss_machine, sample_set, train_responses);
+
+    for(int n = 0; n < sample_set->nb_samples(); n++)
+      train_responses[n] += _weak_learners[w]->response(sample_set, n);
+
+    (*global.log_stream) << "Weak learner " << w
+         << " depth " << _weak_learners[w]->depth()
+         << " nb_leaves " << _weak_learners[w]->nb_leaves()
+         << " train loss " << loss_machine->loss(sample_set, train_responses)
+         << endl;
+
+  }
+
+  (*global.log_stream) << "Built a classifier with " << _nb_weak_learners << " weak_learners." << endl;
+}
+
+// Forwards the used array to tag_used_features() of every weak
+// learner, in the order they are stored.
+void BoostedClassifier::tag_used_features(bool *used) {
+  DecisionTree **tree = _weak_learners;
+  for(int remaining = _nb_weak_learners; remaining > 0; --remaining, ++tree) {
+    (*tree)->tag_used_features(used);
+  }
+}
+
+// Forwards the new_indexes array to re_index_features() of every weak
+// learner, in the order they are stored.
+void BoostedClassifier::re_index_features(int *new_indexes) {
+  DecisionTree **tree = _weak_learners;
+  for(int remaining = _nb_weak_learners; remaining > 0; --remaining, ++tree) {
+    (*tree)->re_index_features(new_indexes);
+  }
+}
+
+void BoostedClassifier::read(istream *is) {
+  if(_weak_learners) {
+    cerr << "Can not read over an existing BoostedClassifier" << endl;
+    exit(1);
+  }
+
+  read_var(is, &_nb_weak_learners);
+  _weak_learners = new DecisionTree *[_nb_weak_learners];
+  for(int w = 0; w < _nb_weak_learners; w++) {
+    _weak_learners[w] = new DecisionTree();
+    _weak_learners[w]->read(is);
+    (*global.log_stream) << "Read tree " << w << " of depth "
+                         << _weak_learners[w]->depth() << " with "
+                         << _weak_learners[w]->nb_leaves() << " leaves." << endl;
+  }
+
+  (*global.log_stream)
+    << "Read BoostedClassifier containing " << _nb_weak_learners << " weak learners." << endl;
+}
+
+// Writes the classifier to the stream os: the CLASSIFIER_BOOSTED
+// identifier, the number of weak learners, then each decision tree in
+// order. Exits the program, before writing anything, if weak learners
+// are expected but their array was never allocated.
+void BoostedClassifier::write(ostream *os) {
+  // A classifier built with a positive count but neither trained nor
+  // read has no trees to write: refuse now rather than emit a
+  // truncated record and dereference a null array
+  if(_nb_weak_learners > 0 && !_weak_learners) {
+    cerr << "Can not write a BoostedClassifier which has not been built" << endl;
+    exit(1);
+  }
+
+  unsigned int id = CLASSIFIER_BOOSTED;
+  write_var(os, &id);
+
+  write_var(os, &_nb_weak_learners);
+  for(int w = 0; w < _nb_weak_learners; w++)
+    _weak_learners[w]->write(os);
+}