automatic commit
[folded-ctf.git] / tools.cc
1
2 ///////////////////////////////////////////////////////////////////////////
3 // This program is free software: you can redistribute it and/or modify  //
4 // it under the terms of the version 3 of the GNU General Public License //
5 // as published by the Free Software Foundation.                         //
6 //                                                                       //
7 // This program is distributed in the hope that it will be useful, but   //
8 // WITHOUT ANY WARRANTY; without even the implied warranty of            //
9 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      //
10 // General Public License for more details.                              //
11 //                                                                       //
12 // You should have received a copy of the GNU General Public License     //
13 // along with this program. If not, see <http://www.gnu.org/licenses/>.  //
14 //                                                                       //
15 // Written by Francois Fleuret                                           //
16 // (C) Idiap Research Institute                                          //
17 //                                                                       //
18 // Contact <francois.fleuret@idiap.ch> for comments & bug reports        //
19 ///////////////////////////////////////////////////////////////////////////
20
21 #include "misc.h"
22 #include "tools.h"
23 #include "fusion_sort.h"
24
25 scalar_t robust_sampling(int nb, scalar_t *weights, int nb_to_sample, int *sampled) {
26   ASSERT(nb > 0);
27   if(nb == 1) {
28     for(int k = 0; k < nb_to_sample; k++) sampled[k] = 0;
29     return weights[0];
30   } else {
31     scalar_t *pair_weights = new scalar_t[(nb+1)/2];
32     for(int k = 0; k < nb/2; k++)
33       pair_weights[k] = weights[2 * k] + weights[2 * k + 1];
34     if(nb%2)
35       pair_weights[(nb+1)/2 - 1] = weights[nb-1];
36     scalar_t result = robust_sampling((nb+1)/2, pair_weights, nb_to_sample, sampled);
37     for(int k = 0; k < nb_to_sample; k++) {
38       int s = sampled[k];
39       // There is a bit of a trick for the isolated sample in the odd
40       // case. Since the corresponding pair weight is the same as the
41       // one sample alone, the test is always true and the isolated
42       // sample will be taken for sure.
43       if(drand48() * pair_weights[s] <= weights[2 * s])
44         sampled[k] = 2 * s;
45       else
46         sampled[k] = 2 * s + 1;
47     }
48     delete[] pair_weights;
49     return result;
50   }
51 }
52
53 void print_roc_small_pos(ostream *out,
54                          int nb_pos, scalar_t *pos_responses,
55                          int nb_neg, scalar_t *neg_responses,
56                          scalar_t fas_factor) {
57
58   scalar_t *sorted_pos_responses = new scalar_t[nb_pos];
59
60   fusion_sort(nb_pos, pos_responses, sorted_pos_responses);
61
62   int *bins = new int[nb_pos + 1];
63   for(int k = 0; k <= nb_pos; k++) bins[k] = 0;
64
65   for(int k = 0; k < nb_neg; k++) {
66     scalar_t r = neg_responses[k];
67
68     if(r < sorted_pos_responses[0])
69       bins[0]++;
70
71     else if(r >= sorted_pos_responses[nb_pos - 1])
72       bins[nb_pos]++;
73
74     else {
75       int a = 0;
76       int b = nb_pos - 1;
77       int c = 0;
78
79       while(a < b - 1) {
80         c = (a + b) / 2;
81         if(r < sorted_pos_responses[c])
82           b = c;
83         else
84           a = c;
85       }
86
87       // Beware of identical positive responses
88       while(c < nb_pos && r >= sorted_pos_responses[c])
89         c++;
90
91       bins[c]++;
92     }
93   }
94
95   int s = nb_neg;
96   for(int k = 0; k < nb_pos; k++) {
97     s -= bins[k];
98     if(k == 0 || sorted_pos_responses[k-1] < sorted_pos_responses[k]) {
99       (*out) << (scalar_t(s) / scalar_t(nb_neg)) * fas_factor
100              << " "
101              << scalar_t(nb_pos - k)/scalar_t(nb_pos)
102              << " "
103              << sorted_pos_responses[k]
104              << endl;
105     }
106   }
107
108   delete[] bins;
109   delete[] sorted_pos_responses;
110 }