63a65cc208f67348358f492adb2b1e90cb745498
[folded-ctf.git] / loss_machine.cc
1
2 ///////////////////////////////////////////////////////////////////////////
3 // This program is free software: you can redistribute it and/or modify  //
4 // it under the terms of the version 3 of the GNU General Public License //
5 // as published by the Free Software Foundation.                         //
6 //                                                                       //
7 // This program is distributed in the hope that it will be useful, but   //
8 // WITHOUT ANY WARRANTY; without even the implied warranty of            //
9 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      //
10 // General Public License for more details.                              //
11 //                                                                       //
12 // You should have received a copy of the GNU General Public License     //
13 // along with this program. If not, see <http://www.gnu.org/licenses/>.  //
14 //                                                                       //
15 // Written by Francois Fleuret                                           //
16 // (C) Idiap Research Institute                                          //
17 //                                                                       //
18 // Contact <francois.fleuret@idiap.ch> for comments & bug reports        //
19 ///////////////////////////////////////////////////////////////////////////
20
21 #include "tools.h"
22 #include "loss_machine.h"
23
24 LossMachine::LossMachine(int loss_type) {
25   _loss_type = loss_type;
26 }
27
28 void LossMachine::get_loss_derivatives(SampleSet *samples,
29                                        scalar_t *responses,
30                                        scalar_t *derivatives) {
31
32   switch(_loss_type) {
33
34   case LOSS_EXPONENTIAL:
35     {
36       for(int n = 0; n < samples->nb_samples(); n++) {
37         derivatives[n] =
38           - samples->label(n) * exp( - samples->label(n) * responses[n]);
39       }
40     }
41     break;
42
43   case LOSS_HINGE:
44     {
45       for(int n = 0; n < samples->nb_samples(); n++) {
46         if(samples->label(n) != 0 && samples->label(n) * responses[n] < 1)
47           derivatives[n] = 1;
48         else
49           derivatives[n] = 0;
50       }
51     }
52     break;
53
54   case LOSS_LOGISTIC:
55     {
56       for(int n = 0; n < samples->nb_samples(); n++) {
57         if(samples->label(n) == 0)
58           derivatives[n] = 0.0;
59         else
60           derivatives[n] = samples->label(n) * 1/(1 + exp(samples->label(n) * responses[n]));
61       }
62     }
63     break;
64
65   default:
66     cerr << "Unknown loss type in BoostedClassifier::get_loss_derivatives."
67          << endl;
68     exit(1);
69   }
70
71 }
72
73 scalar_t LossMachine::loss(SampleSet *samples, scalar_t *responses) {
74   scalar_t l = 0;
75
76   switch(_loss_type) {
77
78   case LOSS_EXPONENTIAL:
79     {
80       for(int n = 0; n < samples->nb_samples(); n++) {
81         l += exp( - samples->label(n) * responses[n]);
82         ASSERT(!isinf(l));
83       }
84     }
85     break;
86
87   case LOSS_EV_REGULARIZED:
88     {
89       scalar_t sum_pos = 0, sum_sq_pos = 0, nb_pos = 0, m_pos, v_pos;
90       scalar_t sum_neg = 0, sum_sq_neg = 0, nb_neg = 0, m_neg, v_neg;
91
92       for(int n = 0; n < samples->nb_samples(); n++) {
93         if(samples->label(n) > 0) {
94           sum_pos += responses[n];
95           sum_sq_pos += sq(responses[n]);
96           nb_pos += 1.0;
97         } else if(samples->label(n) < 0) {
98           sum_neg += responses[n];
99           sum_sq_neg += sq(responses[n]);
100           nb_neg += 1.0;
101         }
102       }
103
104       l = 0;
105
106       if(nb_pos > 0) {
107         m_pos = sum_pos / nb_pos;
108         v_pos = sum_sq_pos/(nb_pos - 1) - sq(sum_pos)/(nb_pos * (nb_pos - 1));
109         l += nb_pos * exp(v_pos/2 - m_pos);
110       }
111
112       if(nb_neg > 0) {
113         m_neg = sum_neg / nb_neg;
114         v_neg = sum_sq_neg/(nb_neg - 1) - sq(sum_neg)/(nb_neg * (nb_neg - 1));
115         l += nb_neg * exp(v_neg/2 + m_neg);
116       }
117
118     }
119     break;
120
121   case LOSS_HINGE:
122     {
123       for(int n = 0; n < samples->nb_samples(); n++) {
124         if(samples->label(n) != 0) {
125           if(samples->label(n) * responses[n] < 1)
126             l += (1 - samples->label(n) * responses[n]);
127         }
128       }
129     }
130     break;
131
132   case LOSS_LOGISTIC:
133     {
134       for(int n = 0; n < samples->nb_samples(); n++) {
135         if(samples->label(n) != 0) {
136           scalar_t u = - samples->label(n) * responses[n];
137           if(u > 20) {
138             l += u;
139           } if(u > -20) {
140             l += log(1 + exp(u));
141           }
142         }
143       }
144     }
145     break;
146
147   default:
148     cerr << "Unknown loss type in LossMachine::loss." << endl;
149     exit(1);
150   }
151
152   return l;
153 }
154
155 scalar_t LossMachine::optimal_weight(SampleSet *sample_set,
156                                      scalar_t *weak_learner_responses,
157                                      scalar_t *current_responses) {
158
159   switch(_loss_type) {
160
161   case LOSS_EXPONENTIAL:
162     {
163       scalar_t num = 0, den = 0, z;
164       for(int n = 0; n < sample_set->nb_samples(); n++) {
165         z = sample_set->label(n) * weak_learner_responses[n];
166         if(z > 0) {
167           num += exp( - sample_set->label(n) * current_responses[n]);
168         } else if(z < 0) {
169           den += exp( - sample_set->label(n) * current_responses[n]);
170         }
171       }
172
173       return 0.5 * log(num / den);
174     }
175     break;
176
177   case LOSS_EV_REGULARIZED:
178     {
179
180       scalar_t u = 0, du = -0.1;
181       scalar_t *responses = new scalar_t[sample_set->nb_samples()];
182
183       scalar_t l, prev_l = -1;
184
185       const scalar_t minimum_delta_for_optimization = 1e-5;
186
187       scalar_t shift = 0;
188
189       {
190         scalar_t sum_pos = 0, sum_sq_pos = 0, nb_pos = 0, m_pos, v_pos;
191         scalar_t sum_neg = 0, sum_sq_neg = 0, nb_neg = 0, m_neg, v_neg;
192
193         for(int n = 0; n < sample_set->nb_samples(); n++) {
194           if(sample_set->label(n) > 0) {
195             sum_pos += responses[n];
196             sum_sq_pos += sq(responses[n]);
197             nb_pos += 1.0;
198           } else if(sample_set->label(n) < 0) {
199             sum_neg += responses[n];
200             sum_sq_neg += sq(responses[n]);
201             nb_neg += 1.0;
202           }
203         }
204
205         if(nb_pos > 0) {
206           m_pos = sum_pos / nb_pos;
207           v_pos = sum_sq_pos/(nb_pos - 1) - sq(sum_pos)/(nb_pos * (nb_pos - 1));
208           shift = max(shift, v_pos/2 - m_pos);
209         }
210
211         if(nb_neg > 0) {
212           m_neg = sum_neg / nb_neg;
213           v_neg = sum_sq_neg/(nb_neg - 1) - sq(sum_neg)/(nb_neg * (nb_neg - 1));
214           shift = max(shift, v_neg/2 + m_neg);
215         }
216
217 //         (*global.log_stream) << "nb_pos = " << nb_pos << " nb_neg = " << nb_neg << endl;
218
219       }
220
221       int nb = 0;
222
223       while(nb < 100 && abs(du) > minimum_delta_for_optimization) {
224         nb++;
225
226 //         (*global.log_stream) << "l = " << l << " u = " << u << " du = " << du << endl;
227
228         u += du;
229         for(int s = 0; s < sample_set->nb_samples(); s++) {
230           responses[s] = current_responses[s] + u * weak_learner_responses[s] ;
231         }
232
233         {
234           scalar_t sum_pos = 0, sum_sq_pos = 0, nb_pos = 0, m_pos, v_pos;
235           scalar_t sum_neg = 0, sum_sq_neg = 0, nb_neg = 0, m_neg, v_neg;
236
237           for(int n = 0; n < sample_set->nb_samples(); n++) {
238             if(sample_set->label(n) > 0) {
239               sum_pos += responses[n];
240               sum_sq_pos += sq(responses[n]);
241               nb_pos += 1.0;
242             } else if(sample_set->label(n) < 0) {
243               sum_neg += responses[n];
244               sum_sq_neg += sq(responses[n]);
245               nb_neg += 1.0;
246             }
247           }
248
249           l = 0;
250
251           if(nb_pos > 0) {
252             m_pos = sum_pos / nb_pos;
253             v_pos = sum_sq_pos/(nb_pos - 1) - sq(sum_pos)/(nb_pos * (nb_pos - 1));
254             l += nb_pos * exp(v_pos/2 - m_pos - shift);
255           }
256
257           if(nb_neg > 0) {
258             m_neg = sum_neg / nb_neg;
259             v_neg = sum_sq_neg/(nb_neg - 1) - sq(sum_neg)/(nb_neg * (nb_neg - 1));
260             l += nb_neg * exp(v_neg/2 + m_neg - shift);
261           }
262
263         }
264
265         if(l > prev_l) du = du * -0.25;
266         prev_l = l;
267       }
268
269       delete[] responses;
270
271       return u;
272     }
273
274   case LOSS_HINGE:
275   case LOSS_LOGISTIC:
276     {
277
278       scalar_t u = 0, du = -0.1;
279       scalar_t *responses = new scalar_t[sample_set->nb_samples()];
280
281       scalar_t l, prev_l = -1;
282
283       const scalar_t minimum_delta_for_optimization = 1e-5;
284
285       int n = 0;
286       while(n < 100 && abs(du) > minimum_delta_for_optimization) {
287         n++;
288         u += du;
289         for(int s = 0; s < sample_set->nb_samples(); s++) {
290           responses[s] = current_responses[s] + u * weak_learner_responses[s] ;
291         }
292         l = loss(sample_set, responses);
293         if(l > prev_l) du = du * -0.25;
294         prev_l = l;
295       }
296
297       (*global.log_stream) << "END l = " << l << " du = " << du << endl;
298
299       delete[] responses;
300
301       return u;
302     }
303
304   default:
305     cerr << "Unknown loss type in LossMachine::optimal_weight." << endl;
306     exit(1);
307   }
308
309 }
310
311 void LossMachine::subsample(int nb, scalar_t *labels, scalar_t *responses,
312                             int nb_to_sample, int *sample_nb_occurences, scalar_t *sample_responses,
313                             int allow_duplicates) {
314
315   switch(_loss_type) {
316
317   case LOSS_EXPONENTIAL:
318     {
319       scalar_t *weights = new scalar_t[nb];
320
321       for(int n = 0; n < nb; n++) {
322         if(labels[n] == 0) {
323           weights[n] = 0;
324         } else {
325           weights[n] = exp( - labels[n] * responses[n]);
326         }
327         sample_nb_occurences[n] = 0;
328         sample_responses[n] = 0.0;
329       }
330
331       scalar_t total_weight;
332       int nb_sampled = 0, sum_sample_nb_occurences = 0;
333
334       int *sampled_indexes = new int[nb_to_sample];
335
336       (*global.log_stream) << "Sampling " << nb_to_sample << " samples." << endl;
337
338       do {
339         total_weight = robust_sampling(nb,
340                                        weights,
341                                        nb_to_sample,
342                                        sampled_indexes);
343
344         for(int k = 0; nb_sampled < nb_to_sample && k < nb_to_sample; k++) {
345           int i = sampled_indexes[k];
346           if(allow_duplicates || sample_nb_occurences[i] == 0) nb_sampled++;
347           sample_nb_occurences[i]++;
348           sum_sample_nb_occurences++;
349         }
350       } while(nb_sampled < nb_to_sample);
351
352       (*global.log_stream) << "nb_sampled = " << nb_sampled << " nb_to_sample = " << nb_to_sample << endl;
353
354       (*global.log_stream) << "Done." << endl;
355
356       delete[] sampled_indexes;
357
358       scalar_t unit_weight = log(total_weight / scalar_t(sum_sample_nb_occurences));
359
360       for(int n = 0; n < nb; n++) {
361         if(sample_nb_occurences[n] > 0) {
362           if(allow_duplicates) {
363             sample_responses[n] = - labels[n] * unit_weight;
364           } else {
365             sample_responses[n] = - labels[n] * (unit_weight + log(scalar_t(sample_nb_occurences[n])));
366             sample_nb_occurences[n] = 1;
367           }
368         }
369       }
370
371       delete[] weights;
372
373     }
374     break;
375
376   default:
377     cerr << "Unknown loss type in LossMachine::resample." << endl;
378     exit(1);
379   }
380
381
382 }