*** empty log message ***
[folded-ctf.git] / tools.cc
1
2 ///////////////////////////////////////////////////////////////////////////
3 // This program is free software: you can redistribute it and/or modify  //
4 // it under the terms of the version 3 of the GNU General Public License //
5 // as published by the Free Software Foundation.                         //
6 //                                                                       //
7 // This program is distributed in the hope that it will be useful, but   //
8 // WITHOUT ANY WARRANTY; without even the implied warranty of            //
9 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      //
10 // General Public License for more details.                              //
11 //                                                                       //
12 // You should have received a copy of the GNU General Public License     //
13 // along with this program. If not, see <http://www.gnu.org/licenses/>.  //
14 //                                                                       //
15 // Written by Francois Fleuret, (C) IDIAP                                //
16 // Contact <francois.fleuret@idiap.ch> for comments & bug reports        //
17 ///////////////////////////////////////////////////////////////////////////
18
19 #include "misc.h"
20 #include "tools.h"
21 #include "fusion_sort.h"
22
23 scalar_t robust_sampling(int nb, scalar_t *weights, int nb_to_sample, int *sampled) {
24   ASSERT(nb > 0);
25   if(nb == 1) {
26     for(int k = 0; k < nb_to_sample; k++) sampled[k] = 0;
27     return weights[0];
28   } else {
29     scalar_t *pair_weights = new scalar_t[(nb+1)/2];
30     for(int k = 0; k < nb/2; k++)
31       pair_weights[k] = weights[2 * k] + weights[2 * k + 1];
32     if(nb%2)
33       pair_weights[(nb+1)/2 - 1] = weights[nb-1];
34     scalar_t result = robust_sampling((nb+1)/2, pair_weights, nb_to_sample, sampled);
35     for(int k = 0; k < nb_to_sample; k++) {
36       int s = sampled[k];
37       // There is a bit of a trick for the isolated sample in the odd
38       // case. Since the corresponding pair weight is the same as the
39       // one sample alone, the test is always true and the isolated
40       // sample will be taken for sure.
41       if(drand48() * pair_weights[s] <= weights[2 * s])
42         sampled[k] = 2 * s;
43       else
44         sampled[k] = 2 * s + 1;
45     }
46     delete[] pair_weights;
47     return result;
48   }
49 }
50
51 void print_roc_small_pos(ostream *out,
52                          int nb_pos, scalar_t *pos_responses,
53                          int nb_neg, scalar_t *neg_responses,
54                          scalar_t fas_factor) {
55
56   scalar_t *sorted_pos_responses = new scalar_t[nb_pos];
57
58   fusion_sort(nb_pos, pos_responses, sorted_pos_responses);
59
60   int *bins = new int[nb_pos + 1];
61   for(int k = 0; k <= nb_pos; k++) bins[k] = 0;
62
63   for(int k = 0; k < nb_neg; k++) {
64     scalar_t r = neg_responses[k];
65
66     if(r < sorted_pos_responses[0])
67       bins[0]++;
68
69     else if(r >= sorted_pos_responses[nb_pos - 1])
70       bins[nb_pos]++;
71
72     else {
73       int a = 0;
74       int b = nb_pos - 1;
75       int c = 0;
76
77       while(a < b - 1) {
78         c = (a + b) / 2;
79         if(r < sorted_pos_responses[c])
80           b = c;
81         else
82           a = c;
83       }
84
85       // Beware of identical positive responses
86       while(c < nb_pos && r >= sorted_pos_responses[c])
87         c++;
88
89       bins[c]++;
90     }
91   }
92
93   int s = nb_neg;
94   for(int k = 0; k < nb_pos; k++) {
95     s -= bins[k];
96     if(k == 0 || sorted_pos_responses[k-1] < sorted_pos_responses[k]) {
97       (*out) << (scalar_t(s) / scalar_t(nb_neg)) * fas_factor
98              << " "
99              << scalar_t(nb_pos - k)/scalar_t(nb_pos)
100              << " "
101              << sorted_pos_responses[k]
102              << endl;
103     }
104   }
105
106   delete[] bins;
107   delete[] sorted_pos_responses;
108 }