///////////////////////////////////////////////////////////////////////////
// This program is free software: you can redistribute it and/or modify  //
// it under the terms of the version 3 of the GNU General Public License //
// as published by the Free Software Foundation.                         //
//                                                                       //
// This program is distributed in the hope that it will be useful, but   //
// WITHOUT ANY WARRANTY; without even the implied warranty of            //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      //
// General Public License for more details.                              //
//                                                                       //
// You should have received a copy of the GNU General Public License     //
// along with this program. If not, see <http://www.gnu.org/licenses/>.  //
//                                                                       //
// Written by Francois Fleuret                                           //
// (C) Idiap Research Institute                                          //
//                                                                       //
// Contact <francois.fleuret@idiap.ch> for comments & bug reports        //
///////////////////////////////////////////////////////////////////////////

#include "decision_tree.h"
#include "fusion_sort.h"

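// A node is either internal, with both _subtree_lesser and _subtree_greater
// non-null and a (_feature_index, _threshold) pair defining the split, or a
// leaf, in which case only _weight is meaningful.
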
DecisionTree::DecisionTree() {
  _feature_index = -1;
  _threshold = 0;
  _weight = 0;
  _subtree_greater = 0;
  _subtree_lesser = 0;
}

DecisionTree::~DecisionTree() {
  if(_subtree_lesser)
    delete _subtree_lesser;
  if(_subtree_greater)
    delete _subtree_greater;
}

int DecisionTree::nb_leaves() {
  if(_subtree_lesser || _subtree_greater)
    return _subtree_lesser->nb_leaves() + _subtree_greater->nb_leaves();
  else
    return 1;
}

int DecisionTree::depth() {
  if(_subtree_lesser || _subtree_greater)
    return 1 + max(_subtree_lesser->depth(), _subtree_greater->depth());
  else
    return 1;
}

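// Response of the tree for one sample: descend through the splits according
// to the sample's feature values until a leaf is reached, and return the
// weight of that leaf.
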
scalar_t DecisionTree::response(SampleSet *sample_set, int n_sample) {
  if(_subtree_lesser && _subtree_greater) {
    if(sample_set->feature_value(n_sample, _feature_index) < _threshold)
      return _subtree_lesser->response(sample_set, n_sample);
    else
      return _subtree_greater->response(sample_set, n_sample);
  } else {
    return _weight;
  }
}

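// Picks the (feature, threshold) split for this node. For every feature, the
// samples are sorted by feature value and every threshold lying between two
// consecutive distinct values is scanned. The score of a threshold is
//
//   | sum_{s : x_s(f) >= threshold} dL/dr_s  -  sum_{s : x_s(f) < threshold} dL/dr_s |
//
// where dL/dr_s is the derivative of the loss with respect to the response
// of sample s. The sum starts equal to the total of all the derivatives
// (every sample on the "greater" side), and moving one sample across the
// threshold changes it by twice that sample's derivative, hence the update
// sum -= 2 * loss_derivatives[t]. In effect this retains the split whose
// +/-1 response is maximally correlated with the loss gradient, the usual
// criterion for weak learners in gradient boosting. If every feature is
// constant over the samples, _feature_index remains -1.
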
void DecisionTree::pick_best_split(SampleSet *sample_set, scalar_t *loss_derivatives) {

  int nb_samples = sample_set->nb_samples();

  scalar_t *responses = new scalar_t[nb_samples];
  int *indexes = new int[nb_samples];
  int *sorted_indexes = new int[nb_samples];

  scalar_t max_abs_sum = 0;
  _feature_index = -1;

  for(int f = 0; f < sample_set->nb_features(); f++) {
    scalar_t sum = 0;

    for(int s = 0; s < nb_samples; s++) {
      indexes[s] = s;
      responses[s] = sample_set->feature_value(s, f);
      sum += loss_derivatives[s];
    }

    indexed_fusion_sort(nb_samples, indexes, sorted_indexes, responses);

    int t, u = sorted_indexes[0];
    for(int s = 0; s < nb_samples - 1; s++) {
      t = u;
      u = sorted_indexes[s + 1];
      sum -= 2 * loss_derivatives[t];

      if(responses[t] < responses[u] && abs(sum) > max_abs_sum) {
        max_abs_sum = abs(sum);
        _feature_index = f;
        _threshold = (responses[t] + responses[u])/2;
      }
    }
  }

  delete[] indexes;
  delete[] sorted_indexes;
  delete[] responses;
}

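// Recursive training. A node is split as long as the maximum tree depth has
// not been reached and both classes still have at least
// min_nb_samples_for_split samples. If no valid split is found, or if the
// split would leave one side with too few samples of both classes, the node
// becomes a leaf whose weight is the optimal constant to add to the current
// responses according to the loss, clamped for numerical stability.
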
void DecisionTree::train(LossMachine *loss_machine,
                         SampleSet *sample_set,
                         scalar_t *current_responses,
                         scalar_t *loss_derivatives,
                         int depth) {

  if(_subtree_lesser || _subtree_greater || _feature_index >= 0) {
    cerr << "You can not re-train a tree." << endl;
    abort();
  }

  int nb_samples = sample_set->nb_samples();

  int nb_pos = 0, nb_neg = 0;
  for(int s = 0; s < sample_set->nb_samples(); s++) {
    if(sample_set->label(s) > 0) nb_pos++;
    else if(sample_set->label(s) < 0) nb_neg++;
  }

  (*global.log_stream) << "Training tree" << endl;
  (*global.log_stream) << "  nb_samples " << nb_samples << endl;
  (*global.log_stream) << "  depth " << depth << endl;
  (*global.log_stream) << "  nb_pos = " << nb_pos << endl;
  (*global.log_stream) << "  nb_neg = " << nb_neg << endl;

  if(depth >= global.tree_depth_max)
    (*global.log_stream) << "  Maximum depth reached." << endl;
  if(nb_pos < min_nb_samples_for_split)
    (*global.log_stream) << "  Not enough positive samples." << endl;
  if(nb_neg < min_nb_samples_for_split)
    (*global.log_stream) << "  Not enough negative samples." << endl;

  if(depth < global.tree_depth_max &&
     nb_pos >= min_nb_samples_for_split &&
     nb_neg >= min_nb_samples_for_split) {

    pick_best_split(sample_set, loss_derivatives);

    if(_feature_index >= 0) {
      // Heap allocation rather than a variable-length array, which is a GCC
      // extension and not standard C++.
      int *indexes = new int[nb_samples];
      scalar_t *parted_current_responses = new scalar_t[nb_samples];
      scalar_t *parted_loss_derivatives = new scalar_t[nb_samples];

      int nb_lesser = 0, nb_greater = 0;
      int nb_lesser_pos = 0, nb_lesser_neg = 0, nb_greater_pos = 0, nb_greater_neg = 0;

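      // Partition the sample indexes and the per-sample arrays in place:
      // samples going to the lesser subtree are packed at the front and
      // samples going to the greater subtree are packed from the back (in
      // reverse order, which is harmless since training does not depend on
      // the sample ordering).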
      for(int s = 0; s < nb_samples; s++) {
        if(sample_set->feature_value(s, _feature_index) < _threshold) {
          indexes[nb_lesser] = s;
          parted_current_responses[nb_lesser] = current_responses[s];
          parted_loss_derivatives[nb_lesser] = loss_derivatives[s];

          if(sample_set->label(s) > 0)
            nb_lesser_pos++;
          else if(sample_set->label(s) < 0)
            nb_lesser_neg++;

          nb_lesser++;
        } else {
          nb_greater++;

          indexes[nb_samples - nb_greater] = s;
          parted_current_responses[nb_samples - nb_greater] = current_responses[s];
          parted_loss_derivatives[nb_samples - nb_greater] = loss_derivatives[s];

          if(sample_set->label(s) > 0)
            nb_greater_pos++;
          else if(sample_set->label(s) < 0)
            nb_greater_neg++;
        }
      }

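      // Recurse only if each side keeps at least min_nb_samples_for_split
      // samples of one of the two classes; otherwise the node falls through
      // to the leaf case below.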
      if((nb_lesser_pos >= min_nb_samples_for_split ||
          nb_lesser_neg >= min_nb_samples_for_split) &&
         (nb_greater_pos >= min_nb_samples_for_split ||
          nb_greater_neg >= min_nb_samples_for_split)) {

        _subtree_lesser = new DecisionTree();

        {
          SampleSet sub_sample_set(sample_set, nb_lesser, indexes);

          _subtree_lesser->train(loss_machine,
                                 &sub_sample_set,
                                 parted_current_responses,
                                 parted_loss_derivatives,
                                 depth + 1);
        }

        _subtree_greater = new DecisionTree();

        {
          SampleSet sub_sample_set(sample_set, nb_greater, indexes + nb_lesser);

          _subtree_greater->train(loss_machine,
                                  &sub_sample_set,
                                  parted_current_responses + nb_lesser,
                                  parted_loss_derivatives + nb_lesser,
                                  depth + 1);
        }
      }

      delete[] parted_current_responses;
      delete[] parted_loss_derivatives;
      delete[] indexes;
    } else {
      (*global.log_stream) << "Could not find a feature for split." << endl;
    }
  }

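  // Leaf case: either no split was attempted, the split search failed, or
  // the split was rejected above. Ask the loss machine for the optimal
  // uniform weight for these samples and clamp it to avoid huge steps.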
  if(!(_subtree_greater && _subtree_lesser)) {
    scalar_t *tmp_responses = new scalar_t[nb_samples];
    for(int s = 0; s < nb_samples; s++)
      tmp_responses[s] = 1;

    _weight = loss_machine->optimal_weight(sample_set, tmp_responses, current_responses);

    const scalar_t max_weight = 10.0;

    if(_weight > max_weight) {
      _weight = max_weight;
    } else if(_weight < - max_weight) {
      _weight = - max_weight;
    }

    (*global.log_stream) << "  _weight " << _weight << endl;

    delete[] tmp_responses;
  }
}

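// Public entry point: compute the derivative of the loss with respect to
// each sample's current response, then train recursively from depth 0.
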
void DecisionTree::train(LossMachine *loss_machine,
                         SampleSet *sample_set,
                         scalar_t *current_responses) {

  scalar_t *loss_derivatives = new scalar_t[sample_set->nb_samples()];

  loss_machine->get_loss_derivatives(sample_set, current_responses, loss_derivatives);

  train(loss_machine, sample_set, current_responses, loss_derivatives, 0);

  delete[] loss_derivatives;
}

//////////////////////////////////////////////////////////////////////

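// Sets used[f] for every feature index f appearing in an internal node.
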
void DecisionTree::tag_used_features(bool *used) {
  if(_subtree_lesser && _subtree_greater) {
    used[_feature_index] = true;
    _subtree_lesser->tag_used_features(used);
    _subtree_greater->tag_used_features(used);
  }
}

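// Remaps every internal node's feature index through new_indexes, typically
// after the feature set has been filtered or re-ordered.
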
void DecisionTree::re_index_features(int *new_indexes) {
  if(_subtree_lesser && _subtree_greater) {
    _feature_index = new_indexes[_feature_index];
    _subtree_lesser->re_index_features(new_indexes);
    _subtree_greater->re_index_features(new_indexes);
  }
}

//////////////////////////////////////////////////////////////////////

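// Serialization: nodes are written and read in pre-order. Each node stores
// its feature index, threshold and weight, followed by an integer flag
// telling whether two subtrees follow.
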
void DecisionTree::read(istream *is) {
  if(_subtree_lesser || _subtree_greater) {
    cerr << "You can not read in an existing tree." << endl;
    abort();
  }

  read_var(is, &_feature_index);
  read_var(is, &_threshold);
  read_var(is, &_weight);

  int split;
  read_var(is, &split);

  if(split) {
    _subtree_lesser = new DecisionTree();
    _subtree_lesser->read(is);
    _subtree_greater = new DecisionTree();
    _subtree_greater->read(is);
  }
}

void DecisionTree::write(ostream *os) {

  write_var(os, &_feature_index);
  write_var(os, &_threshold);
  write_var(os, &_weight);

  int split;
  if(_subtree_lesser && _subtree_greater) {
    split = 1;
    write_var(os, &split);
    _subtree_lesser->write(os);
    _subtree_greater->write(os);
  } else {
    split = 0;
    write_var(os, &split);
  }
}
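
//////////////////////////////////////////////////////////////////////

// A minimal usage sketch (illustrative only; `samples` and `loss` stand for
// a SampleSet* and a LossMachine* built elsewhere in folded-ctf, and the
// variable names here are hypothetical):
//
//   int n = samples->nb_samples();
//   scalar_t *responses = new scalar_t[n];
//   for(int s = 0; s < n; s++) responses[s] = 0;  // start from zero responses
//
//   DecisionTree tree;
//   tree.train(loss, samples, responses);    // fit one tree to the loss gradient
//   scalar_t r = tree.response(samples, 0);  // response of the tree for sample 0
//
//   delete[] responses;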