Fixed a bug in the threshold given in the ROC output.
[data-tool.git] / data-tool.cc
1
2 /*
3  *  data-tool is a command line tool to do simple statistical
4  *  processing on numerical data.
5  *
6  *  Copyright (c) 2009 Francois Fleuret
7  *  Written by Francois Fleuret <francois@fleuret.org>
8  *
9  *  This file is part of data-tool.
10  *
11  *  data-tool is free software: you can redistribute it and/or modify
12  *  it under the terms of the GNU General Public License version 3 as
13  *  published by the Free Software Foundation.
14  *
15  *  data-tool is distributed in the hope that it will be useful, but
16  *  WITHOUT ANY WARRANTY; without even the implied warranty of
17  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  *  General Public License for more details.
19  *
20  *  You should have received a copy of the GNU General Public License
21  *  along with data-tool.  If not, see <http://www.gnu.org/licenses/>.
22  *
23  */
24
25 #include <iostream>
26 #include <cmath>
27 #include <stdlib.h>
28 #include <string.h>
29
30 using namespace std;
31
32 struct Couple {
33   int index;
34   double value;
35 };
36
37 int compare_couple(const void *a, const void *b) {
38   if(((Couple *) a)->value < ((Couple *) b)->value) return -1;
39   else if(((Couple *) a)->value > ((Couple *) b)->value) return 1;
40   else return 0;
41 }
42
43 double *inflate_array(double *x, int current_size, int new_size) {
44   double *xx = new double[new_size];
45   for(int n = 0; n < current_size; n++) xx[n] = x[n];
46   delete[] x;
47   return xx;
48 }
49
50 char *next_word(char *buffer, char *r, int buffer_size) {
51   char *s;
52   s = buffer;
53   if(r != NULL)
54     {
55       if(*r == '"') {
56         r++;
57         while((*r != '"') && (*r != '\0') &&
58               (s<buffer+buffer_size-1))
59           *s++ = *r++;
60         if(*r == '"') r++;
61       } else {
62         while((*r != '\r') && (*r != '\n') && (*r != '\0') &&
63               (*r != '\t') && (*r != ' ') && (*r != ',') &&
64               (s<buffer+buffer_size-1))
65           *s++ = *r++;
66       }
67
68       while((*r == ' ') || (*r == '\t') || (*r == ',')) r++;
69       if((*r == '\0') || (*r=='\r') || (*r=='\n')) r = NULL;
70     }
71   *s = '\0';
72   return r;
73 }
74
75 void check_opt(int argc, char **argv, int n_opt, int n, const char *help) {
76   if(n_opt + n >= argc) {
77     cerr << "ERROR: Missing argument for " << argv[n_opt] << ". Expecting " << help << "." << endl;
78     exit(1);
79   }
80 }
81
82 void print_help_and_exit(int e) {
83   cout << "Simple data processing tool. Written by Francois Fleuret." << endl
84        << endl
85        << "This application takes data from the standard input and prints" << endl
86        << "the result on the standard output. It expects either a list of" << endl
87        << "float values (to produce histograms, cumulative distribution functions" << endl
88        << "or the mean, variance, etc.) or a list of couples of values of the form" << endl
89        << "x y on each line (where the sign of x tells the class and y the parameter" << endl
90        << "value) to compute the ROC curve or the ROC curve surface.\n" << endl
91        << "The options are:" << endl
92        << "  --help" << endl
93        << "  --roc" << endl
94        << "  --roc-surface" << endl
95        << "  --normalize" << endl
96        << "  --histo" << endl
97        << "  --cumul" << endl
98        << "  --misc" << endl
99        << "  --auto-extrema" << endl
100        << "  --xbounds <float: xmin> <float: xmax>" << endl
101        << "  --ybounds <float: ymin> <float: ymax>" << endl
102        << "  --nb-bins <int: number of bins>" << endl;
103   exit(e);
104 }
105
106 void check_single_processing(bool unknown_processing) {
107   if(!unknown_processing) {
108     cerr << "ERROR: You can't do two different processings." << endl;
109     exit(1);
110   }
111 }
112
113 int main(int argc, char **argv) {
114   double xmin = 0, xmax = 1, ymin = 0, ymax = 1;
115   int nb_bins = 10;
116   const int buffer_size = 1024;
117
118   char line[buffer_size], token[buffer_size];
119   bool auto_extrema = false;
120   bool normalize = false;
121
122   int i = 1;
123
124   enum { UNKNOWN, ROC, ROC_SURFACE, HISTO, CUMUL, MISC } processing = UNKNOWN;
125
126   // Parsing the command line arguments ////////////////////////////////
127
128   while(i < argc) {
129
130     if(argc == 1 || strcmp(argv[i], "--help") == 0) print_help_and_exit(0);
131
132     else if(strcmp(argv[i], "--roc") == 0) {
133       check_single_processing(processing == UNKNOWN);
134       processing = ROC;
135       i++;
136     }
137
138     else if(strcmp(argv[i], "--roc-surface") == 0) {
139       check_single_processing(processing == UNKNOWN);
140       processing = ROC_SURFACE;
141       i++;
142     }
143
144     else if(strcmp(argv[i], "--cumul") == 0) {
145       check_single_processing(processing == UNKNOWN);
146       processing = CUMUL;
147       i++;
148     }
149
150     else if(strcmp(argv[i], "--normalize") == 0) {
151       normalize = true;
152       i++;
153     }
154
155     else if(strcmp(argv[i], "--histo") == 0) {
156       check_single_processing(processing == UNKNOWN);
157       processing = HISTO;
158       i++;
159     }
160
161     else if(strcmp(argv[i], "--misc") == 0) {
162       check_single_processing(processing == UNKNOWN);
163       processing = MISC;
164       i++;
165     }
166
167     else if(strcmp(argv[i], "--auto-extrema") == 0) {
168       auto_extrema = true;
169       i++;
170     }
171
172     else if(strcmp(argv[i], "--xbounds") == 0) {
173       check_opt(argc, argv, i, 2, "<float: xmin> <float: xmax>");
174       xmin = atof(argv[i+1]);
175       xmax = atof(argv[i+2]);
176       if(xmin >= xmax) {
177         cerr << "ERROR: Incorrect bounds." << endl;
178         exit(1);
179       }
180       i += 3;
181     }
182
183     else if(strcmp(argv[i], "--ybounds") == 0) {
184       check_opt(argc, argv, i, 2, "<float: ymin> <float: ymax>");
185       ymin = atof(argv[i+1]);
186       ymax = atof(argv[i+2]);
187       if(ymin >= ymax) {
188         cerr << "ERROR: Incorrect bounds." << endl;
189         exit(1);
190       }
191       i += 3;
192     }
193
194     else if(strcmp(argv[i], "--nb-bins") == 0) {
195       check_opt(argc, argv, i, 1, "<int: number of bins>");
196       nb_bins = atoi(argv[i+1]);
197       if(nb_bins < 1) {
198         cerr << "ERROR: Incorrect number of bins." << endl;
199         exit(1);
200       }
201       i += 2;
202     }
203
204     else {
205       cerr << "ERROR: Unknown option " << argv[i]  << endl;
206       print_help_and_exit(1);
207     }
208   }
209
210   // Processing the data ///////////////////////////////////////////////
211
212   switch(processing) {
213
214   case CUMUL:
215
216     {
217       int nb_samples = 0, nb_samples_max = 50000;
218       double *x = new double[nb_samples_max];
219
220       while(!cin.eof()) {
221         if(nb_samples == nb_samples_max) {
222           x = inflate_array(x, nb_samples_max, 2 * nb_samples_max);
223           nb_samples_max = 2 * nb_samples_max;
224         }
225
226         cin.getline(line, buffer_size);
227
228         if(line[0]) {
229           char *s = line;
230           s = next_word(token, s, buffer_size);
231           x[nb_samples] = atof(token);
232           nb_samples++;
233         }
234       }
235
236       Couple tmp[nb_samples];
237       for(int n = 0; n < nb_samples; n++) {
238         tmp[n].index = n;
239         tmp[n].value = x[n];
240       }
241
242       qsort(tmp, nb_samples, sizeof(Couple), compare_couple);
243
244       for(int n = 0; n < nb_samples; n++)
245         cout << tmp[n].value << " " << double(n)/double(nb_samples)  << endl;
246
247       delete[] x;
248
249     }
250
251     break;
252
253   case ROC:
254   case ROC_SURFACE:
255
256     {
257       int nb_samples = 0, nb_samples_max = 1000;
258       double *x = new double[nb_samples_max], *y = new double[nb_samples_max];
259
260       while(!cin.eof()) {
261         if(nb_samples == nb_samples_max) {
262           x = inflate_array(x, nb_samples_max, 2 * nb_samples_max);
263           y = inflate_array(y, nb_samples_max, 2 * nb_samples_max);
264           nb_samples_max = 2 * nb_samples_max;
265         }
266
267         cin.getline(line, buffer_size);
268
269         if(line[0]) {
270           char *s = line;
271           s = next_word(token, s, buffer_size);
272           x[nb_samples] = atof(token);
273           s = next_word(token, s, buffer_size);
274           y[nb_samples] = atof(token);
275           nb_samples++;
276         }
277       }
278
279       Couple tmp[nb_samples];
280       int nb_rn = 0, nb_rp = 0, nb_fp = 0, nb_fn = 0;
281
282       bool binary = true;
283       for(int n = 0; binary && n < nb_samples; n++) binary &= (x[n] == 0 || x[n] == 1);
284       if(binary) {
285         cerr << "WARNING: your classes are binary, I process them accordingly." << endl;
286         for(int n = 0; n < nb_samples; n++) x[n] = 2 * x[n] - 1;
287       }
288
289       for(int n = 0; n < nb_samples; n++) {
290         tmp[n].index = n;
291         tmp[n].value = y[n];
292         if(x[n] >= 0) nb_rp++;
293         else { nb_rn++; nb_fp++; }
294       }
295
296       if(nb_rp == 0) cerr << "WARNING: No true positive." << endl;
297       if(nb_rn == 0) cerr << "WARNING: No true negative." << endl;
298
299       qsort(tmp, nb_samples, sizeof(Couple), compare_couple);
300
301       if(processing == ROC) {
302         for(int n = 0; n < nb_samples - 1; n++) {
303           if(x[tmp[n].index] >= 0) nb_fn++;
304           else                     nb_fp--;
305           if(tmp[n].value < tmp[n+1].value) {
306             cout << double(nb_fp)/double(nb_rn) << " "
307                  << 1 - double(nb_fn) / double(nb_rp) << " "
308                  << (tmp[n].value + tmp[n+1].value)/2 << " "
309                  << endl;
310           }
311         }
312       } else {
313         double surface = 0;
314         double cx = double(nb_fp)/double(nb_rn), cy = 1 - double(nb_fn) / double(nb_rp);
315         for(int n = 0; n < nb_samples - 1; n++) {
316           if(x[tmp[n].index] >= 0) nb_fn++;
317           else                     nb_fp--;
318           if(tmp[n].value < tmp[n+1].value) {
319             double ncx = double(nb_fp)/double(nb_rn), ncy = 1 - double(nb_fn) / double(nb_rp);
320             surface += (cx - ncx) * cy;
321             cx = ncx; cy = ncy;
322           }
323         }
324         cout << surface  << endl;
325       }
326
327       delete[] x; delete[] y;
328
329     }
330
331     break;
332
333   case HISTO:
334
335     {
336       int nb_samples = 0, nb_samples_max = 1000;
337       double *x = new double[nb_samples_max];
338
339       while(!cin.eof()) {
340         if(nb_samples == nb_samples_max) {
341           x = inflate_array(x, nb_samples_max, 2 * nb_samples_max);
342           nb_samples_max = 2 * nb_samples_max;
343         }
344
345         cin.getline(line, buffer_size);
346
347         if(line[0]) {
348           char *s = line;
349           s = next_word(token, s, buffer_size);
350           x[nb_samples] = atof(token);
351           if(auto_extrema) {
352             if(nb_samples == 0 || x[nb_samples] > xmax) xmax = x[nb_samples];
353             if(nb_samples == 0 || x[nb_samples] < xmin) xmin = x[nb_samples];
354           }
355           nb_samples++;
356         }
357       }
358
359       int nb[nb_bins];
360       for(int n = 0; n < nb_bins; n++) nb[n] = 0;
361
362       int nb_total = 0;
363       for(int s = 0; s < nb_samples; s++) {
364         int n = int((x[s] - xmin)/(xmax - xmin) * nb_bins);
365         if(n >= 0 && n < nb_bins) nb[n]++;
366         else {
367           cerr << "WARNING: value " << x[s] << " is out of histogram." << endl;
368         }
369         nb_total++;
370       }
371
372       if(normalize) {
373         for(int n = 0; n < nb_bins; n++)
374           cout << xmin + ((xmax - xmin) * n) / double(nb_bins) << " "
375                << (nb[n] / double(nb_total))/((xmax - xmin) / double(nb_bins))  << endl;
376       } else {
377         for(int n = 0; n < nb_bins; n++)
378           cout << xmin + ((xmax - xmin) * n) / double(nb_bins) << " "
379                << nb[n] / double(nb_total)  << endl;
380       }
381     }
382
383     break;
384
385   case MISC:
386
387     {
388       int nb_samples = 0, nb_samples_max = 1000;
389       double *x = new double[nb_samples_max];
390       int nb = 0;
391       double min = 0, max = 0;
392       double sum = 0, sumsq = 0;
393
394       while(!cin.eof()) {
395         if(nb_samples == nb_samples_max) {
396           x = inflate_array(x, nb_samples_max, 2 * nb_samples_max);
397           nb_samples_max = 2 * nb_samples_max;
398         }
399
400         cin.getline(line, buffer_size);
401         char *s = line;
402         if(line[0]) {
403           s = next_word(token, s, buffer_size);
404           x[nb_samples] = atof(token);
405           nb_samples++;
406           double x = atof(token);
407           if(nb == 0 || x > max) max = x;
408           if(nb == 0 || x < min) min = x;
409           sum += x;
410           sumsq += x*x;
411           nb++;
412         }
413       }
414
415       Couple tmp[nb_samples];
416       for(int n = 0; n < nb_samples; n++) {
417         tmp[n].index = n;
418         tmp[n].value = x[n];
419       }
420
421       qsort(tmp, nb_samples, sizeof(Couple), compare_couple);
422
423       delete[] x;
424
425       double mu = sum / double(nb);
426       double sigma = (sumsq - sum * mu) / double(nb - 1);
427       double stdd = sqrt(sigma);
428
429       cout << "MIN " << min
430            << " MAX " << max
431            << " MU " << mu
432            << " SIGMA " << sigma
433            << " STDD " << stdd
434            << " SUM " << sum
435            << " MEDIAN " << tmp[nb_samples/2].value
436            << " QUANTILE0.1 " << tmp[int(nb_samples * 0.1)].value
437            << " QUANTILE0.9 " << tmp[int(nb_samples * 0.9)].value
438            << endl;
439
440     }
441
442     break;
443
444   default:
445     cerr << "ERROR: You must choose a processing type." << endl;
446     exit(1);
447   }
448
449 }