/* decision_tree.h */
2 ///////////////////////////////////////////////////////////////////////////
3 // This program is free software: you can redistribute it and/or modify  //
4 // it under the terms of the version 3 of the GNU General Public License //
5 // as published by the Free Software Foundation.                         //
6 //                                                                       //
7 // This program is distributed in the hope that it will be useful, but   //
8 // WITHOUT ANY WARRANTY; without even the implied warranty of            //
9 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      //
10 // General Public License for more details.                              //
11 //                                                                       //
12 // You should have received a copy of the GNU General Public License     //
13 // along with this program. If not, see <http://www.gnu.org/licenses/>.  //
14 //                                                                       //
15 // Written by Francois Fleuret                                           //
16 // (C) Idiap Research Institute                                          //
17 //                                                                       //
18 // Contact <francois.fleuret@idiap.ch> for comments & bug reports        //
19 ///////////////////////////////////////////////////////////////////////////
20
21 /*
22
23   An implementation of the classifier with a decision tree. Each node
24   simply thresholds one of the component, and is chosen for maximum
25   loss reduction locally during training. The leaves are labelled with
26   the classifier response, which is chosen again for maximum loss
27   reduction.
28
29  */
30
31 #ifndef DECISION_TREE_H
32 #define DECISION_TREE_H
33
34 #include "misc.h"
35 #include "classifier.h"
36 #include "sample_set.h"
37 #include "loss_machine.h"
38
class DecisionTree : public Classifier {

  // A node holding fewer samples than this is not split further and
  // becomes a leaf.
  static const int min_nb_samples_for_split = 5;

  // Index of the sample component this node thresholds.
  int _feature_index;
  // Split value: samples are routed to the lesser / greater subtree
  // depending on which side of this threshold their feature falls.
  scalar_t _threshold;
  // Response carried by this node, chosen during training for maximum
  // loss reduction (meaningful at the leaves).
  scalar_t _weight;

  // Child subtrees for samples below / above _threshold.
  // NOTE(review): presumably both are null at a leaf -- confirm
  // against the .cc implementation.
  DecisionTree *_subtree_lesser, *_subtree_greater;

  // Selects _feature_index and _threshold for maximum local loss
  // reduction, given the per-sample loss derivatives.
  void pick_best_split(SampleSet *sample_set,
                       scalar_t *loss_derivatives);

  // Internal training routine with an explicit depth argument,
  // used when growing the subtrees below the public train().
  void train(LossMachine *loss_machine,
             SampleSet *sample_set,
             scalar_t *current_responses,
             scalar_t *loss_derivatives,
             int depth);

public:

  DecisionTree();
  ~DecisionTree();

  // Number of leaves in the tree.
  int nb_leaves();
  // Depth of the tree.
  int depth();

  // Response of the tree for sample n_sample of the set.
  scalar_t response(SampleSet *sample_set, int n_sample);

  // Trains the tree on the sample set, starting from the given
  // per-sample current responses.
  void train(LossMachine *loss_machine,
             SampleSet *sample_set,
             scalar_t *current_responses);

  // Flags in `used' the features the tree actually thresholds.
  void tag_used_features(bool *used);
  // Remaps every _feature_index in the tree through new_indexes
  // (after the feature set has been re-indexed).
  void re_index_features(int *new_indexes);

  // Stream (de)serialization of the whole tree.
  void read(istream *is);
  void write(ostream *os);
};
78
79 #endif