Removed the definition of basename, which confuses an existing system one.
[folded-ctf.git] / decision_tree.h
1 /*
2  *  folded-ctf is an implementation of the folded hierarchy of
3  *  classifiers for object detection, developed by Francois Fleuret
4  *  and Donald Geman.
5  *
6  *  Copyright (c) 2008 Idiap Research Institute, http://www.idiap.ch/
7  *  Written by Francois Fleuret <francois.fleuret@idiap.ch>
8  *
9  *  This file is part of folded-ctf.
10  *
11  *  folded-ctf is free software: you can redistribute it and/or modify
12  *  it under the terms of the GNU General Public License version 3 as
13  *  published by the Free Software Foundation.
14  *
15  *  folded-ctf is distributed in the hope that it will be useful, but
16  *  WITHOUT ANY WARRANTY; without even the implied warranty of
17  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  *  General Public License for more details.
19  *
20  *  You should have received a copy of the GNU General Public License
21  *  along with folded-ctf.  If not, see <http://www.gnu.org/licenses/>.
22  *
23  */
24
25 /*
26
27   An implementation of the classifier with a decision tree. Each node
28   simply thresholds one of the component, and is chosen for maximum
29   loss reduction locally during training. The leaves are labelled with
30   the classifier response, which is chosen again for maximum loss
31   reduction.
32
33  */
34
35 #ifndef DECISION_TREE_H
36 #define DECISION_TREE_H
37
38 #include "misc.h"
39 #include "classifier.h"
40 #include "sample_set.h"
41 #include "loss_machine.h"
42
43 class DecisionTree : public Classifier {
44
45   static const int min_nb_samples_for_split = 5;
46
47   int _feature_index;
48   scalar_t _threshold;
49   scalar_t _weight;
50
51   DecisionTree *_subtree_lesser, *_subtree_greater;
52
53   void pick_best_split(SampleSet *sample_set,
54                        scalar_t *loss_derivatives);
55
56   void train(LossMachine *loss_machine,
57              SampleSet *sample_set,
58              scalar_t *current_responses,
59              scalar_t *loss_derivatives,
60              int depth);
61
62 public:
63
64   DecisionTree();
65   ~DecisionTree();
66
67   int nb_leaves();
68   int depth();
69
70   scalar_t response(SampleSet *sample_set, int n_sample);
71
72   void train(LossMachine *loss_machine,
73              SampleSet *sample_set,
74              scalar_t *current_responses);
75
76   void tag_used_features(bool *used);
77   void re_index_features(int *new_indexes);
78
79   void read(istream *is);
80   void write(ostream *os);
81 };
82
83 #endif