automatic commit
[folded-ctf.git] / boosted_classifier.cc
1
2 ///////////////////////////////////////////////////////////////////////////
3 // This program is free software: you can redistribute it and/or modify  //
4 // it under the terms of the version 3 of the GNU General Public License //
5 // as published by the Free Software Foundation.                         //
6 //                                                                       //
7 // This program is distributed in the hope that it will be useful, but   //
8 // WITHOUT ANY WARRANTY; without even the implied warranty of            //
9 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      //
10 // General Public License for more details.                              //
11 //                                                                       //
12 // You should have received a copy of the GNU General Public License     //
13 // along with this program. If not, see <http://www.gnu.org/licenses/>.  //
14 //                                                                       //
15 // Written by Francois Fleuret                                           //
16 // (C) Idiap Research Institute                                          //
17 //                                                                       //
18 // Contact <francois.fleuret@idiap.ch> for comments & bug reports        //
19 ///////////////////////////////////////////////////////////////////////////
20
21 #include "classifier_reader.h"
22 #include "fusion_sort.h"
23
24 #include "boosted_classifier.h"
25 #include "tools.h"
26
27 BoostedClassifier::BoostedClassifier(int nb_weak_learners) {
28   _loss_type = global.loss_type;
29   _nb_weak_learners = nb_weak_learners;
30   _weak_learners = 0;
31 }
32
33 BoostedClassifier::BoostedClassifier() {
34   _loss_type = global.loss_type;
35   _nb_weak_learners = 0;
36   _weak_learners = 0;
37 }
38
39 BoostedClassifier::~BoostedClassifier() {
40   if(_weak_learners) {
41     for(int w = 0; w < _nb_weak_learners; w++)
42       delete _weak_learners[w];
43     delete[] _weak_learners;
44   }
45 }
46
47 scalar_t BoostedClassifier::response(SampleSet *sample_set, int n_sample) {
48   scalar_t r = 0;
49   for(int w = 0; w < _nb_weak_learners; w++) {
50     r += _weak_learners[w]->response(sample_set, n_sample);
51     ASSERT(!isnan(r));
52   }
53   return r;
54 }
55
56 void BoostedClassifier::train(LossMachine *loss_machine,
57                               SampleSet *sample_set, scalar_t *train_responses) {
58
59   if(_weak_learners) {
60     cerr << "Can not re-train a BoostedClassifier" << endl;
61     exit(1);
62   }
63
64   int nb_pos = 0, nb_neg = 0;
65
66   for(int s = 0; s < sample_set->nb_samples(); s++) {
67     if(sample_set->label(s) > 0) nb_pos++;
68     else if(sample_set->label(s) < 0) nb_neg++;
69   }
70
71   _weak_learners = new DecisionTree *[_nb_weak_learners];
72
73   (*global.log_stream) << "With " << nb_pos << " positive and " << nb_neg << " negative samples." << endl;
74
75   for(int w = 0; w  < _nb_weak_learners; w++) {
76
77     _weak_learners[w] = new DecisionTree();
78     _weak_learners[w]->train(loss_machine, sample_set, train_responses);
79
80     for(int n = 0; n < sample_set->nb_samples(); n++)
81       train_responses[n] += _weak_learners[w]->response(sample_set, n);
82
83     (*global.log_stream) << "Weak learner " << w
84          << " depth " << _weak_learners[w]->depth()
85          << " nb_leaves " << _weak_learners[w]->nb_leaves()
86          << " train loss " << loss_machine->loss(sample_set, train_responses)
87          << endl;
88
89   }
90
91   (*global.log_stream) << "Built a classifier with " << _nb_weak_learners << " weak_learners." << endl;
92 }
93
94 void BoostedClassifier::tag_used_features(bool *used) {
95   for(int w = 0; w < _nb_weak_learners; w++)
96     _weak_learners[w]->tag_used_features(used);
97 }
98
99 void BoostedClassifier::re_index_features(int *new_indexes) {
100   for(int w = 0; w < _nb_weak_learners; w++)
101     _weak_learners[w]->re_index_features(new_indexes);
102 }
103
104 void BoostedClassifier::read(istream *is) {
105   if(_weak_learners) {
106     cerr << "Can not read over an existing BoostedClassifier" << endl;
107     exit(1);
108   }
109
110   read_var(is, &_nb_weak_learners);
111   _weak_learners = new DecisionTree *[_nb_weak_learners];
112   for(int w = 0; w < _nb_weak_learners; w++) {
113     _weak_learners[w] = new DecisionTree();
114     _weak_learners[w]->read(is);
115     (*global.log_stream) << "Read tree " << w << " of depth "
116                          << _weak_learners[w]->depth() << " with "
117                          << _weak_learners[w]->nb_leaves() << " leaves." << endl;
118   }
119
120   (*global.log_stream)
121     << "Read BoostedClassifier containing " << _nb_weak_learners << " weak learners." << endl;
122 }
123
124 void BoostedClassifier::write(ostream *os) {
125   unsigned int id;
126   id = CLASSIFIER_BOOSTED;
127   write_var(os, &id);
128
129   write_var(os, &_nb_weak_learners);
130   for(int w = 0; w < _nb_weak_learners; w++)
131     _weak_learners[w]->write(os);
132 }