automatic commit
diff --git a/decision_tree.cc b/decision_tree.cc
new file mode 100644
index 0000000..e2b3daa
--- /dev/null
+++ b/decision_tree.cc
@@ -0,0 +1,335 @@
+
+///////////////////////////////////////////////////////////////////////////
+// This program is free software: you can redistribute it and/or modify  //
+// it under the terms of version 3 of the GNU General Public License     //
+// as published by the Free Software Foundation.                         //
+//                                                                       //
+// This program is distributed in the hope that it will be useful, but   //
+// WITHOUT ANY WARRANTY; without even the implied warranty of            //
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      //
+// General Public License for more details.                              //
+//                                                                       //
+// You should have received a copy of the GNU General Public License     //
+// along with this program. If not, see <http://www.gnu.org/licenses/>.  //
+//                                                                       //
+// Written by Francois Fleuret, (C) IDIAP                                //
+// Contact <francois.fleuret@idiap.ch> for comments & bug reports        //
+///////////////////////////////////////////////////////////////////////////
+
+#include "decision_tree.h"
+#include "fusion_sort.h"
+
+DecisionTree::DecisionTree() {
+  _feature_index = -1;
+  _threshold = 0;
+  _weight = 0;
+  _subtree_greater = 0;
+  _subtree_lesser = 0;
+}
+
+DecisionTree::~DecisionTree() {
+  if(_subtree_lesser)
+    delete _subtree_lesser;
+  if(_subtree_greater)
+    delete _subtree_greater;
+}
+
+int DecisionTree::nb_leaves() {
+  if(_subtree_lesser || _subtree_greater)
+    return _subtree_lesser->nb_leaves() + _subtree_greater->nb_leaves();
+  else
+    return 1;
+}
+
+int DecisionTree::depth() {
+  if(_subtree_lesser || _subtree_greater)
+    return 1 + max(_subtree_lesser->depth(), _subtree_greater->depth());
+  else
+    return 1;
+}
+
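+// Returns the prediction for sample n_sample: the sample is routed
+// down the tree according to the stored thresholds, and the weight of
+// the leaf it reaches is returned.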
+scalar_t DecisionTree::response(SampleSet *sample_set, int n_sample) {
+  if(_subtree_lesser && _subtree_greater) {
+    if(sample_set->feature_value(n_sample, _feature_index) < _threshold)
+      return _subtree_lesser->response(sample_set, n_sample);
+    else
+      return _subtree_greater->response(sample_set, n_sample);
+  } else {
+    return _weight;
+  }
+}
+
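+// Chooses the (feature, threshold) pair for this node. For each
+// feature, the samples are sorted by feature value and a running sum
+// is maintained that equals the sum of the loss derivatives on the
+// 'greater' side of the candidate threshold minus the sum on the
+// 'lesser' side. The retained split maximizes |sum|, with the
+// threshold placed halfway between two consecutive distinct feature
+// values. If no valid split is found, _feature_index remains -1.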
+void DecisionTree::pick_best_split(SampleSet *sample_set, scalar_t *loss_derivatives) {
+
+  int nb_samples = sample_set->nb_samples();
+
+  scalar_t *responses = new scalar_t[nb_samples];
+  int *indexes = new int[nb_samples];
+  int *sorted_indexes = new int[nb_samples];
+
+  scalar_t max_abs_sum = 0;
+  _feature_index = -1;
+
+  for(int f = 0; f < sample_set->nb_features(); f++) {
+    scalar_t sum = 0;
+
+    for(int s = 0; s < nb_samples; s++) {
+      indexes[s] = s;
+      responses[s] = sample_set->feature_value(s, f);
+      sum += loss_derivatives[s];
+    }
+
+    indexed_fusion_sort(nb_samples, indexes, sorted_indexes, responses);
+
+    int t, u = sorted_indexes[0];
+    for(int s = 0; s < nb_samples - 1; s++) {
+      t = u;
+      u = sorted_indexes[s + 1];
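+      // Sample t moves from the 'greater' to the 'lesser' side of the
+      // threshold, which flips the sign of its contribution to sum.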
+      sum -= 2 * loss_derivatives[t];
+
+      if(responses[t] < responses[u] && abs(sum) > max_abs_sum) {
+        max_abs_sum = abs(sum);
+        _feature_index = f;
+        _threshold = (responses[t] + responses[u])/2;
+      }
+    }
+  }
+
+  delete[] indexes;
+  delete[] sorted_indexes;
+  delete[] responses;
+}
+
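+// Recursively grows the tree. If the maximum depth has not been
+// reached and both classes have at least min_nb_samples_for_split
+// samples, the node picks the best split, partitions the samples into
+// the 'lesser' and 'greater' parts of indexes[], and trains a subtree
+// on each part, provided both parts retain enough samples. If no
+// split is made, the node becomes a leaf whose weight is the optimal
+// one for the loss given the current responses, clamped to
+// [-max_weight, max_weight].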
+void DecisionTree::train(LossMachine *loss_machine,
+                         SampleSet *sample_set,
+                         scalar_t *current_responses,
+                         scalar_t *loss_derivatives,
+                         int depth) {
+
+  if(_subtree_lesser || _subtree_greater || _feature_index >= 0) {
+    cerr << "You cannot re-train a tree." << endl;
+    abort();
+  }
+
+  int nb_samples = sample_set->nb_samples();
+
+  int nb_pos = 0, nb_neg = 0;
+  for(int s = 0; s < sample_set->nb_samples(); s++) {
+    if(sample_set->label(s) > 0) nb_pos++;
+    else if(sample_set->label(s) < 0) nb_neg++;
+  }
+
+  (*global.log_stream) << "Training tree" << endl;
+  (*global.log_stream) << "  nb_samples " << nb_samples << endl;
+  (*global.log_stream) << "  depth " << depth << endl;
+  (*global.log_stream) << "  nb_pos = " << nb_pos << endl;
+  (*global.log_stream) << "  nb_neg = " << nb_neg << endl;
+
+  if(depth >= global.tree_depth_max)
+    (*global.log_stream) << "  Maximum depth reached." << endl;
+  if(nb_pos < min_nb_samples_for_split)
+    (*global.log_stream) << "  Not enough positive samples." << endl;
+  if(nb_neg < min_nb_samples_for_split)
+    (*global.log_stream) << "  Not enough negative samples." << endl;
+
+  if(depth < global.tree_depth_max &&
+     nb_pos >= min_nb_samples_for_split &&
+     nb_neg >= min_nb_samples_for_split) {
+
+    pick_best_split(sample_set, loss_derivatives);
+
+    if(_feature_index >= 0) {
+      // A heap allocation avoids the non-standard variable-length array.
+      int *indexes = new int[nb_samples];
+      scalar_t *parted_current_responses = new scalar_t[nb_samples];
+      scalar_t *parted_loss_derivatives = new scalar_t[nb_samples];
+
+      int nb_lesser = 0, nb_greater = 0;
+      int nb_lesser_pos = 0, nb_lesser_neg = 0, nb_greater_pos = 0, nb_greater_neg = 0;
+
+      for(int s = 0; s < nb_samples; s++) {
+        if(sample_set->feature_value(s, _feature_index) < _threshold) {
+          indexes[nb_lesser] = s;
+          parted_current_responses[nb_lesser] = current_responses[s];
+          parted_loss_derivatives[nb_lesser] = loss_derivatives[s];
+
+          if(sample_set->label(s) > 0)
+            nb_lesser_pos++;
+          else if(sample_set->label(s) < 0)
+            nb_lesser_neg++;
+
+          nb_lesser++;
+        } else {
+          nb_greater++;
+
+          indexes[nb_samples - nb_greater] = s;
+          parted_current_responses[nb_samples - nb_greater] = current_responses[s];
+          parted_loss_derivatives[nb_samples - nb_greater] = loss_derivatives[s];
+
+          if(sample_set->label(s) > 0)
+            nb_greater_pos++;
+          else if(sample_set->label(s) < 0)
+            nb_greater_neg++;
+        }
+      }
+
+      if((nb_lesser_pos >= min_nb_samples_for_split ||
+          nb_lesser_neg >= min_nb_samples_for_split) &&
+         (nb_greater_pos >= min_nb_samples_for_split ||
+          nb_greater_neg >= min_nb_samples_for_split)) {
+
+        _subtree_lesser = new DecisionTree();
+
+        {
+          SampleSet sub_sample_set(sample_set, nb_lesser, indexes);
+
+          _subtree_lesser->train(loss_machine,
+                                 &sub_sample_set,
+                                 parted_current_responses,
+                                 parted_loss_derivatives,
+                                 depth + 1);
+        }
+
+        _subtree_greater = new DecisionTree();
+
+        {
+          SampleSet sub_sample_set(sample_set, nb_greater, indexes + nb_lesser);
+
+          _subtree_greater->train(loss_machine,
+                                  &sub_sample_set,
+                                  parted_current_responses + nb_lesser,
+                                  parted_loss_derivatives + nb_lesser,
+                                  depth + 1);
+        }
+      }
+
+      delete[] indexes;
+      delete[] parted_current_responses;
+      delete[] parted_loss_derivatives;
+    } else {
+      (*global.log_stream) << "Could not find a feature for split." << endl;
+    }
+  }
+
+  if(!(_subtree_greater && _subtree_lesser)) {
+    scalar_t *tmp_responses = new scalar_t[nb_samples];
+    for(int s = 0; s < nb_samples; s++)
+      tmp_responses[s] = 1;
+
+    _weight = loss_machine->optimal_weight(sample_set, tmp_responses, current_responses);
+
+    const scalar_t max_weight = 10.0;
+
+    if(_weight > max_weight) {
+      _weight = max_weight;
+    } else if(_weight < - max_weight) {
+      _weight = - max_weight;
+    }
+
+    (*global.log_stream) << "  _weight " << _weight << endl;
+
+    delete[] tmp_responses;
+  }
+}
+
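+// Public entry point: computes the loss derivatives for the current
+// responses and starts the recursive training at depth 0.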
+void DecisionTree::train(LossMachine *loss_machine,
+                         SampleSet *sample_set,
+                         scalar_t *current_responses) {
+
+  scalar_t *loss_derivatives = new scalar_t[sample_set->nb_samples()];
+
+  loss_machine->get_loss_derivatives(sample_set, current_responses, loss_derivatives);
+
+  train(loss_machine, sample_set, current_responses, loss_derivatives, 0);
+
+  delete[] loss_derivatives;
+}
+
+//////////////////////////////////////////////////////////////////////
+
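+// Sets used[f] to true for every feature f tested at an internal node.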
+void DecisionTree::tag_used_features(bool *used) {
+  if(_subtree_lesser && _subtree_greater) {
+    used[_feature_index] = true;
+    _subtree_lesser->tag_used_features(used);
+    _subtree_greater->tag_used_features(used);
+  }
+}
+
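+// Replaces the feature index of every internal node with its image
+// through new_indexes.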
+void DecisionTree::re_index_features(int *new_indexes) {
+  if(_subtree_lesser && _subtree_greater) {
+    _feature_index = new_indexes[_feature_index];
+    _subtree_lesser->re_index_features(new_indexes);
+    _subtree_greater->re_index_features(new_indexes);
+  }
+}
+
+//////////////////////////////////////////////////////////////////////
+
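+// Reads a tree from a stream. The format, produced by write() below,
+// is, for every node: feature index, threshold, weight, then an
+// integer flag which is non-zero if the node is split, in which case
+// the lesser and greater subtrees follow recursively.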
+void DecisionTree::read(istream *is) {
+  if(_subtree_lesser || _subtree_greater) {
+    cerr << "You cannot read into an existing tree." << endl;
+    abort();
+  }
+
+  read_var(is, &_feature_index);
+  read_var(is, &_threshold);
+  read_var(is, &_weight);
+
+  int split;
+  read_var(is, &split);
+
+  if(split) {
+    _subtree_lesser = new DecisionTree();
+    _subtree_lesser->read(is);
+    _subtree_greater = new DecisionTree();
+    _subtree_greater->read(is);
+  }
+}
+
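+// Writes the tree in the format described above for read().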
+void DecisionTree::write(ostream *os) {
+
+  write_var(os, &_feature_index);
+  write_var(os, &_threshold);
+  write_var(os, &_weight);
+
+  int split;
+  if(_subtree_lesser && _subtree_greater) {
+    split = 1;
+    write_var(os, &split);
+    _subtree_lesser->write(os);
+    _subtree_greater->write(os);
+  } else {
+    split = 0;
+    write_var(os, &split);
+  }
+}