Removed the definition of basename, which conflicted with an existing system function of the same name.
[folded-ctf.git] / decision_tree.cc
/*
 *  folded-ctf is an implementation of the folded hierarchy of
 *  classifiers for object detection, developed by Francois Fleuret
 *  and Donald Geman.
 *
 *  Copyright (c) 2008 Idiap Research Institute, http://www.idiap.ch/
 *  Written by Francois Fleuret <francois.fleuret@idiap.ch>
 *
 *  This file is part of folded-ctf.
 *
 *  folded-ctf is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License version 3 as
 *  published by the Free Software Foundation.
 *
 *  folded-ctf is distributed in the hope that it will be useful, but
 *  WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 *  General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with folded-ctf.  If not, see <http://www.gnu.org/licenses/>.
 *
 */

#include "decision_tree.h"
#include "fusion_sort.h"

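/*
 * Invariant: a node either has both subtrees (an internal node that
 * compares feature _feature_index against _threshold) or neither (a
 * leaf that returns the constant _weight). All the methods below rely
 * on this both-or-neither structure.
 */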
DecisionTree::DecisionTree() {
  _feature_index = -1;
  _threshold = 0;
  _weight = 0;
  _subtree_greater = 0;
  _subtree_lesser = 0;
}

DecisionTree::~DecisionTree() {
  if(_subtree_lesser)
    delete _subtree_lesser;
  if(_subtree_greater)
    delete _subtree_greater;
}

int DecisionTree::nb_leaves() {
  // Test both pointers, as everywhere else in this file: with the
  // original || test, a node that somehow had a single subtree would
  // dereference a null pointer.
  if(_subtree_lesser && _subtree_greater)
    return _subtree_lesser->nb_leaves() + _subtree_greater->nb_leaves();
  else
    return 1;
}

int DecisionTree::depth() {
  if(_subtree_lesser && _subtree_greater)
    return 1 + max(_subtree_lesser->depth(), _subtree_greater->depth());
  else
    return 1;
}

scalar_t DecisionTree::response(SampleSet *sample_set, int n_sample) {
  if(_subtree_lesser && _subtree_greater) {
    if(sample_set->feature_value(n_sample, _feature_index) < _threshold)
      return _subtree_lesser->response(sample_set, n_sample);
    else
      return _subtree_greater->response(sample_set, n_sample);
  } else {
    return _weight;
  }
}

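/*
 * Chooses the feature and threshold for this node. For every feature,
 * the samples are sorted by feature value and the quantity
 *
 *   sum = (sum of loss derivatives above the split)
 *       - (sum of loss derivatives below the split)
 *
 * is maintained incrementally while the split point slides through the
 * sorted order. The split maximizing |sum| over all features is kept,
 * with the threshold placed halfway between the two neighboring
 * feature values.
 */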
void DecisionTree::pick_best_split(SampleSet *sample_set, scalar_t *loss_derivatives) {

  int nb_samples = sample_set->nb_samples();

  scalar_t *responses = new scalar_t[nb_samples];
  int *indexes = new int[nb_samples];
  int *sorted_indexes = new int[nb_samples];

  scalar_t max_abs_sum = 0;
  _feature_index = -1;

  for(int f = 0; f < sample_set->nb_features(); f++) {
    // sum starts as the total of all derivatives and is turned step by
    // step into (sum above the split) - (sum below the split).
    scalar_t sum = 0;

    for(int s = 0; s < nb_samples; s++) {
      indexes[s] = s;
      responses[s] = sample_set->feature_value(s, f);
      sum += loss_derivatives[s];
    }

    indexed_fusion_sort(nb_samples, indexes, sorted_indexes, responses);

    int t, u = sorted_indexes[0];
    for(int s = 0; s < nb_samples - 1; s++) {
      t = u;
      u = sorted_indexes[s + 1];
      // Sample t moves from the "above" side to the "below" side.
      sum -= 2 * loss_derivatives[t];

      // Only split between strictly distinct feature values.
      if(responses[t] < responses[u] && abs(sum) > max_abs_sum) {
        max_abs_sum = abs(sum);
        _feature_index = f;
        _threshold = (responses[t] + responses[u])/2;
      }
    }
  }

  delete[] indexes;
  delete[] sorted_indexes;
  delete[] responses;
}

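/*
 * Recursively grows the tree: unless the depth limit is reached or one
 * of the two classes runs out of samples, the best split is picked and
 * the samples are partitioned between the two subtrees. Nodes that end
 * up as leaves get the loss-optimal constant weight, clamped to
 * [-max_weight, max_weight].
 */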
void DecisionTree::train(LossMachine *loss_machine,
                         SampleSet *sample_set,
                         scalar_t *current_responses,
                         scalar_t *loss_derivatives,
                         int depth) {

  if(_subtree_lesser || _subtree_greater || _feature_index >= 0) {
    cerr << "You cannot re-train a tree." << endl;
    abort();
  }

  int nb_samples = sample_set->nb_samples();

  int nb_pos = 0, nb_neg = 0;
  for(int s = 0; s < sample_set->nb_samples(); s++) {
    if(sample_set->label(s) > 0) nb_pos++;
    else if(sample_set->label(s) < 0) nb_neg++;
  }

  (*global.log_stream) << "Training tree" << endl;
  (*global.log_stream) << "  nb_samples " << nb_samples << endl;
  (*global.log_stream) << "  depth " << depth << endl;
  (*global.log_stream) << "  nb_pos = " << nb_pos << endl;
  (*global.log_stream) << "  nb_neg = " << nb_neg << endl;

  if(depth >= global.tree_depth_max)
    (*global.log_stream) << "  Maximum depth reached." << endl;
  if(nb_pos < min_nb_samples_for_split)
    (*global.log_stream) << "  Not enough positive samples." << endl;
  if(nb_neg < min_nb_samples_for_split)
    (*global.log_stream) << "  Not enough negative samples." << endl;

  if(depth < global.tree_depth_max &&
     nb_pos >= min_nb_samples_for_split &&
     nb_neg >= min_nb_samples_for_split) {

    pick_best_split(sample_set, loss_derivatives);

    if(_feature_index >= 0) {
      // Heap allocation instead of the original variable-length array,
      // which is a GCC extension and not standard C++.
      int *indexes = new int[nb_samples];
      scalar_t *parted_current_responses = new scalar_t[nb_samples];
      scalar_t *parted_loss_derivatives = new scalar_t[nb_samples];

      int nb_lesser = 0, nb_greater = 0;
      int nb_lesser_pos = 0, nb_lesser_neg = 0, nb_greater_pos = 0, nb_greater_neg = 0;

      // Partition in place: the lesser side fills the arrays from the
      // front, the greater side from the back.
      for(int s = 0; s < nb_samples; s++) {
        if(sample_set->feature_value(s, _feature_index) < _threshold) {
          indexes[nb_lesser] = s;
          parted_current_responses[nb_lesser] = current_responses[s];
          parted_loss_derivatives[nb_lesser] = loss_derivatives[s];

          if(sample_set->label(s) > 0)
            nb_lesser_pos++;
          else if(sample_set->label(s) < 0)
            nb_lesser_neg++;

          nb_lesser++;
        } else {
          nb_greater++;

          indexes[nb_samples - nb_greater] = s;
          parted_current_responses[nb_samples - nb_greater] = current_responses[s];
          parted_loss_derivatives[nb_samples - nb_greater] = loss_derivatives[s];

          if(sample_set->label(s) > 0)
            nb_greater_pos++;
          else if(sample_set->label(s) < 0)
            nb_greater_neg++;
        }
      }

      if((nb_lesser_pos >= min_nb_samples_for_split ||
          nb_lesser_neg >= min_nb_samples_for_split) &&
         (nb_greater_pos >= min_nb_samples_for_split ||
          nb_greater_neg >= min_nb_samples_for_split)) {

        _subtree_lesser = new DecisionTree();

        {
          SampleSet sub_sample_set(sample_set, nb_lesser, indexes);

          _subtree_lesser->train(loss_machine,
                                 &sub_sample_set,
                                 parted_current_responses,
                                 parted_loss_derivatives,
                                 depth + 1);
        }

        _subtree_greater = new DecisionTree();

        {
          SampleSet sub_sample_set(sample_set, nb_greater, indexes + nb_lesser);

          _subtree_greater->train(loss_machine,
                                  &sub_sample_set,
                                  parted_current_responses + nb_lesser,
                                  parted_loss_derivatives + nb_lesser,
                                  depth + 1);
        }
      }

      delete[] indexes;
      delete[] parted_current_responses;
      delete[] parted_loss_derivatives;
    } else {
      (*global.log_stream) << "Could not find a feature for split." << endl;
    }
  }

  if(!(_subtree_greater && _subtree_lesser)) {
    scalar_t *tmp_responses = new scalar_t[nb_samples];
    for(int s = 0; s < nb_samples; s++)
      tmp_responses[s] = 1;

    _weight = loss_machine->optimal_weight(sample_set, tmp_responses, current_responses);

    // Clamp the leaf weight to avoid huge steps from near-pure leaves.
    const scalar_t max_weight = 10.0;

    if(_weight > max_weight) {
      _weight = max_weight;
    } else if(_weight < - max_weight) {
      _weight = - max_weight;
    }

    (*global.log_stream) << "  _weight " << _weight << endl;

    delete[] tmp_responses;
  }
}

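/*
 * Public entry point: computes the loss derivatives once at the root
 * and delegates to the recursive train above.
 */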
void DecisionTree::train(LossMachine *loss_machine,
                         SampleSet *sample_set,
                         scalar_t *current_responses) {

  scalar_t *loss_derivatives = new scalar_t[sample_set->nb_samples()];

  loss_machine->get_loss_derivatives(sample_set, current_responses, loss_derivatives);

  train(loss_machine, sample_set, current_responses, loss_derivatives, 0);

  delete[] loss_derivatives;
}

//////////////////////////////////////////////////////////////////////

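/*
 * Feature bookkeeping, presumably used after training to prune and
 * re-index the feature set: tag_used_features marks every feature
 * tested by an internal node, and re_index_features rewrites the
 * indices through a translation table.
 */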
void DecisionTree::tag_used_features(bool *used) {
  if(_subtree_lesser && _subtree_greater) {
    used[_feature_index] = true;
    _subtree_lesser->tag_used_features(used);
    _subtree_greater->tag_used_features(used);
  }
}

void DecisionTree::re_index_features(int *new_indexes) {
  if(_subtree_lesser && _subtree_greater) {
    _feature_index = new_indexes[_feature_index];
    _subtree_lesser->re_index_features(new_indexes);
    _subtree_greater->re_index_features(new_indexes);
  }
}

//////////////////////////////////////////////////////////////////////

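/*
 * Serialization. Nodes are stored in preorder: feature index,
 * threshold and weight, then a split flag which is 1 when the two
 * subtrees follow and 0 for a leaf. read and write must therefore
 * stay exactly symmetric.
 */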
void DecisionTree::read(istream *is) {
  if(_subtree_lesser || _subtree_greater) {
    cerr << "You cannot read in an existing tree." << endl;
    abort();
  }

  read_var(is, &_feature_index);
  read_var(is, &_threshold);
  read_var(is, &_weight);

  int split;
  read_var(is, &split);

  if(split) {
    _subtree_lesser = new DecisionTree();
    _subtree_lesser->read(is);
    _subtree_greater = new DecisionTree();
    _subtree_greater->read(is);
  }
}

void DecisionTree::write(ostream *os) {

  write_var(os, &_feature_index);
  write_var(os, &_threshold);
  write_var(os, &_weight);

  int split;
  if(_subtree_lesser && _subtree_greater) {
    split = 1;
    write_var(os, &split);
    _subtree_lesser->write(os);
    _subtree_greater->write(os);
  } else {
    split = 0;
    write_var(os, &split);
  }
}
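
/*
 * A minimal usage sketch, not part of the original file. How SampleSet
 * and LossMachine instances are constructed is an assumption about the
 * rest of folded-ctf; only the DecisionTree calls below are defined in
 * this file.
 *
 *   DecisionTree tree;
 *   tree.train(&loss_machine, &sample_set, current_responses);
 *
 *   (*global.log_stream) << "depth " << tree.depth()
 *                        << " nb_leaves " << tree.nb_leaves() << endl;
 *
 *   ofstream out("tree.bin");
 *   tree.write(&out);
 *
 *   DecisionTree loaded;
 *   ifstream in("tree.bin");
 *   loaded.read(&in);
 */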