X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=boosted_classifier.cc;fp=boosted_classifier.cc;h=8080ad2eada357cd67662490e8dec204d81d27c5;hb=d922ad61d35e9a6996730bec24b16f8bf7bc426c;hp=0000000000000000000000000000000000000000;hpb=3bb118f5a9462d02ff7d99ef28ecc0d7e23529f9;p=folded-ctf.git diff --git a/boosted_classifier.cc b/boosted_classifier.cc new file mode 100644 index 0000000..8080ad2 --- /dev/null +++ b/boosted_classifier.cc @@ -0,0 +1,130 @@ + +/////////////////////////////////////////////////////////////////////////// +// This program is free software: you can redistribute it and/or modify // +// it under the terms of the version 3 of the GNU General Public License // +// as published by the Free Software Foundation. // +// // +// This program is distributed in the hope that it will be useful, but // +// WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU // +// General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program. If not, see <http://www.gnu.org/licenses/>. 
// +// // +// Written by Francois Fleuret, (C) IDIAP // +// Contact for comments & bug reports // +/////////////////////////////////////////////////////////////////////////// + +#include "classifier_reader.h" +#include "fusion_sort.h" + +#include "boosted_classifier.h" +#include "tools.h" + +BoostedClassifier::BoostedClassifier(int nb_weak_learners) { + _loss_type = global.loss_type; + _nb_weak_learners = nb_weak_learners; + _weak_learners = 0; +} + +BoostedClassifier::BoostedClassifier() { + _loss_type = global.loss_type; + _nb_weak_learners = 0; + _weak_learners = 0; +} + +BoostedClassifier::~BoostedClassifier() { + if(_weak_learners) { + for(int w = 0; w < _nb_weak_learners; w++) + delete _weak_learners[w]; + delete[] _weak_learners; + } +} + +scalar_t BoostedClassifier::response(SampleSet *sample_set, int n_sample) { + scalar_t r = 0; + for(int w = 0; w < _nb_weak_learners; w++) { + r += _weak_learners[w]->response(sample_set, n_sample); + ASSERT(!isnan(r)); + } + return r; +} + +void BoostedClassifier::train(LossMachine *loss_machine, + SampleSet *sample_set, scalar_t *train_responses) { + + if(_weak_learners) { + cerr << "Can not re-train a BoostedClassifier" << endl; + exit(1); + } + + int nb_pos = 0, nb_neg = 0; + + for(int s = 0; s < sample_set->nb_samples(); s++) { + if(sample_set->label(s) > 0) nb_pos++; + else if(sample_set->label(s) < 0) nb_neg++; + } + + _weak_learners = new DecisionTree *[_nb_weak_learners]; + + (*global.log_stream) << "With " << nb_pos << " positive and " << nb_neg << " negative samples." 
<< endl; + + for(int w = 0; w < _nb_weak_learners; w++) { + + _weak_learners[w] = new DecisionTree(); + _weak_learners[w]->train(loss_machine, sample_set, train_responses); + + for(int n = 0; n < sample_set->nb_samples(); n++) + train_responses[n] += _weak_learners[w]->response(sample_set, n); + + (*global.log_stream) << "Weak learner " << w + << " depth " << _weak_learners[w]->depth() + << " nb_leaves " << _weak_learners[w]->nb_leaves() + << " train loss " << loss_machine->loss(sample_set, train_responses) + << endl; + + } + + (*global.log_stream) << "Built a classifier with " << _nb_weak_learners << " weak_learners." << endl; +} + +void BoostedClassifier::tag_used_features(bool *used) { + for(int w = 0; w < _nb_weak_learners; w++) + _weak_learners[w]->tag_used_features(used); +} + +void BoostedClassifier::re_index_features(int *new_indexes) { + for(int w = 0; w < _nb_weak_learners; w++) + _weak_learners[w]->re_index_features(new_indexes); +} + +void BoostedClassifier::read(istream *is) { + if(_weak_learners) { + cerr << "Can not read over an existing BoostedClassifier" << endl; + exit(1); + } + + read_var(is, &_nb_weak_learners); + _weak_learners = new DecisionTree *[_nb_weak_learners]; + for(int w = 0; w < _nb_weak_learners; w++) { + _weak_learners[w] = new DecisionTree(); + _weak_learners[w]->read(is); + (*global.log_stream) << "Read tree " << w << " of depth " + << _weak_learners[w]->depth() << " with " + << _weak_learners[w]->nb_leaves() << " leaves." << endl; + } + + (*global.log_stream) + << "Read BoostedClassifier containing " << _nb_weak_learners << " weak learners." << endl; +} + +void BoostedClassifier::write(ostream *os) { + unsigned int id; + id = CLASSIFIER_BOOSTED; + write_var(os, &id); + + write_var(os, &_nb_weak_learners); + for(int w = 0; w < _nb_weak_learners; w++) + _weak_learners[w]->write(os); +}