X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=folded-ctf.git;a=blobdiff_plain;f=loss_machine.cc;h=413a2a314f785edfab7397cc61b1e0cd6ca8e345;hp=63a65cc208f67348358f492adb2b1e90cb745498;hb=HEAD;hpb=5b78a555f6c7ff20a71d0520db63bc43e69e1f41 diff --git a/loss_machine.cc b/loss_machine.cc index 63a65cc..413a2a3 100644 --- a/loss_machine.cc +++ b/loss_machine.cc @@ -1,22 +1,26 @@ - -/////////////////////////////////////////////////////////////////////////// -// This program is free software: you can redistribute it and/or modify // -// it under the terms of the version 3 of the GNU General Public License // -// as published by the Free Software Foundation. // -// // -// This program is distributed in the hope that it will be useful, but // -// WITHOUT ANY WARRANTY; without even the implied warranty of // -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU // -// General Public License for more details. // -// // -// You should have received a copy of the GNU General Public License // -// along with this program. If not, see . // -// // -// Written by Francois Fleuret // -// (C) Idiap Research Institute // -// // -// Contact for comments & bug reports // -/////////////////////////////////////////////////////////////////////////// +/* + * folded-ctf is an implementation of the folded hierarchy of + * classifiers for object detection, developed by Francois Fleuret + * and Donald Geman. + * + * Copyright (c) 2008 Idiap Research Institute, http://www.idiap.ch/ + * Written by Francois Fleuret + * + * This file is part of folded-ctf. + * + * folded-ctf is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * folded-ctf is distributed in the hope that it will be useful, but + * WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with folded-ctf. If not, see . + * + */ #include "tools.h" #include "loss_machine.h" @@ -84,40 +88,6 @@ scalar_t LossMachine::loss(SampleSet *samples, scalar_t *responses) { } break; - case LOSS_EV_REGULARIZED: - { - scalar_t sum_pos = 0, sum_sq_pos = 0, nb_pos = 0, m_pos, v_pos; - scalar_t sum_neg = 0, sum_sq_neg = 0, nb_neg = 0, m_neg, v_neg; - - for(int n = 0; n < samples->nb_samples(); n++) { - if(samples->label(n) > 0) { - sum_pos += responses[n]; - sum_sq_pos += sq(responses[n]); - nb_pos += 1.0; - } else if(samples->label(n) < 0) { - sum_neg += responses[n]; - sum_sq_neg += sq(responses[n]); - nb_neg += 1.0; - } - } - - l = 0; - - if(nb_pos > 0) { - m_pos = sum_pos / nb_pos; - v_pos = sum_sq_pos/(nb_pos - 1) - sq(sum_pos)/(nb_pos * (nb_pos - 1)); - l += nb_pos * exp(v_pos/2 - m_pos); - } - - if(nb_neg > 0) { - m_neg = sum_neg / nb_neg; - v_neg = sum_sq_neg/(nb_neg - 1) - sq(sum_neg)/(nb_neg * (nb_neg - 1)); - l += nb_neg * exp(v_neg/2 + m_neg); - } - - } - break; - case LOSS_HINGE: { for(int n = 0; n < samples->nb_samples(); n++) { @@ -161,6 +131,7 @@ scalar_t LossMachine::optimal_weight(SampleSet *sample_set, case LOSS_EXPONENTIAL: { scalar_t num = 0, den = 0, z; + for(int n = 0; n < sample_set->nb_samples(); n++) { z = sample_set->label(n) * weak_learner_responses[n]; if(z > 0) { @@ -174,103 +145,6 @@ scalar_t LossMachine::optimal_weight(SampleSet *sample_set, } break; - case LOSS_EV_REGULARIZED: - { - - scalar_t u = 0, du = -0.1; - scalar_t *responses = new scalar_t[sample_set->nb_samples()]; - - scalar_t l, prev_l = -1; - - const scalar_t minimum_delta_for_optimization = 1e-5; - - scalar_t shift = 0; - - { - scalar_t sum_pos = 0, sum_sq_pos = 0, nb_pos = 0, m_pos, v_pos; - scalar_t sum_neg = 0, sum_sq_neg = 0, nb_neg = 0, m_neg, v_neg; - - for(int n = 0; n < sample_set->nb_samples(); n++) { - if(sample_set->label(n) > 0) { - sum_pos += responses[n]; - sum_sq_pos += sq(responses[n]); - nb_pos += 1.0; - } else if(sample_set->label(n) < 0) { - sum_neg += responses[n]; - sum_sq_neg += sq(responses[n]); - nb_neg += 1.0; - } - } - - if(nb_pos > 0) { - m_pos = sum_pos / nb_pos; - v_pos = sum_sq_pos/(nb_pos - 1) - sq(sum_pos)/(nb_pos * (nb_pos - 1)); - shift = max(shift, v_pos/2 - m_pos); - } - - if(nb_neg > 0) { - m_neg = sum_neg / nb_neg; - v_neg = sum_sq_neg/(nb_neg - 1) - sq(sum_neg)/(nb_neg * (nb_neg - 1)); - shift = max(shift, v_neg/2 + m_neg); - } - -// (*global.log_stream) << "nb_pos = " << nb_pos << " nb_neg = " << nb_neg << endl; - - } - - int nb = 0; - - while(nb < 100 && abs(du) > minimum_delta_for_optimization) { - nb++; - -// (*global.log_stream) << "l = " << l << " u = " << u << " du = " << du << endl; - - u += du; - for(int s = 0; s < sample_set->nb_samples(); s++) { - responses[s] = current_responses[s] + u * weak_learner_responses[s] ; - } - - { - scalar_t sum_pos = 0, sum_sq_pos = 0, nb_pos = 0, m_pos, v_pos; - scalar_t sum_neg = 0, sum_sq_neg = 0, nb_neg = 0, m_neg, v_neg; - - for(int n = 0; n < sample_set->nb_samples(); n++) { - if(sample_set->label(n) > 0) { - sum_pos += responses[n]; - sum_sq_pos += sq(responses[n]); - nb_pos += 1.0; - } else if(sample_set->label(n) < 0) { - sum_neg += responses[n]; - sum_sq_neg += sq(responses[n]); - nb_neg += 1.0; - } - } - - l = 0; - - if(nb_pos > 0) { - m_pos = sum_pos / nb_pos; - v_pos = sum_sq_pos/(nb_pos - 1) - sq(sum_pos)/(nb_pos * (nb_pos - 1)); - l += nb_pos * exp(v_pos/2 - m_pos - shift); - } - - if(nb_neg > 0) { - m_neg = sum_neg / nb_neg; - v_neg = sum_sq_neg/(nb_neg - 1) - sq(sum_neg)/(nb_neg * (nb_neg - 1)); - l += nb_neg * exp(v_neg/2 + m_neg - shift); - } - - } - - if(l > prev_l) du = du * -0.25; - prev_l = l; - } - - delete[] responses; - - return u; - } - case LOSS_HINGE: case LOSS_LOGISTIC: { @@ -349,8 +223,6 @@ void LossMachine::subsample(int nb, scalar_t *labels, scalar_t *responses, } } while(nb_sampled < nb_to_sample); - (*global.log_stream) << "nb_sampled = " << nb_sampled << " nb_to_sample = " << nb_to_sample << endl; - (*global.log_stream) << "Done." << endl; delete[] sampled_indexes;