Recoded to UTF-8.
[cmim.git] / cmim.cc
1
2 //////////////////////////////////////////////////////////////////////////////////
3 // This program is free software: you can redistribute it and/or modify         //
4 // it under the terms of the version 3 of the GNU General Public License        //
5 // as published by the Free Software Foundation.                                //
6 //                                                                              //
7 // This program is distributed in the hope that it will be useful, but          //
8 // WITHOUT ANY WARRANTY; without even the implied warranty of                   //
9 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU             //
10 // General Public License for more details.                                     //
11 //                                                                              //
12 // You should have received a copy of the GNU General Public License            //
13 // along with this program. If not, see <http://www.gnu.org/licenses/>.         //
14 //                                                                              //
15 // Written by Francois Fleuret                                                  //
16 // Copyright (C) Ecole Polytechnique Federale de Lausanne                       //
17 // Contact <francois.fleuret@epfl.ch> for comments & bug reports                //
18 //////////////////////////////////////////////////////////////////////////////////
19
20 // $Id: cmim.cc,v 1.4 2007-08-23 08:36:50 fleuret Exp $
21
22 // This software was developped on GNU/Linux systems with many GPL
23 // tools including emacs, gcc, gdb, and bash (see http://www.fsf.org).
24
25 /*
26
27 To test
28
29 ./cmim --feature-selection cmim --classifier bayesian --error ber --train ./train.dat ./classifier.nb 100
30 ./cmim --test ./test.dat ./classifier.nb ./result.dat
31
32 */
33
34 using namespace std;
35
36 #include <cmath>
37 #include <iostream>
38 #include <fstream>
39 #include <stdio.h>
40 #include <string.h>
41 #include <stdlib.h>
42 #include <sys/time.h>
43
44 #include "classifier.h"
45
46 #define BUFFER_SIZE 256
47
48 FeatureSelector *selector;
49 Classifier *classifier;
50 char classifier_type[BUFFER_SIZE] = "bayesian";
51 char feature_selection_type[BUFFER_SIZE] = "cmim";
52 float reg_param = 0.0;
53 bool verbose = true;
54 bool balanced_error = false;
55 int nb_selected_features = 100;
56
57 void check_opt(int argc, char **argv, int n_opt, int n, const char *help) {
58   if(n_opt+n >= argc) {
59     cerr << "Missing argument for " << argv[n_opt] << ".\n";
60     cerr << "Expecting " << help << ".\n";
61     exit(1);
62   }
63 }
64
65 void train(const DataSet &training_set) {
66   timeval tv_start, tv_end;
67   fe_init_tables();
68
69   if(verbose) {
70     cout << "Selecting features with " << feature_selection_type;
71     cout.flush();
72     gettimeofday(&tv_start, 0);
73   }
74
75   cout.flush();
76
77   selector = new FeatureSelector(nb_selected_features);
78
79   if(strcmp(feature_selection_type, "cmim") == 0) selector->cmim(training_set);
80   else if(strcmp(feature_selection_type, "mim") == 0) selector->mim(training_set);
81   else if(strcmp(feature_selection_type, "random") == 0) selector->random(training_set);
82   else {
83     cerr << "Unknown feature selection type " << feature_selection_type << "\n";
84     exit(1);
85   }
86
87   if(verbose) {
88     gettimeofday(&tv_end, 0);
89     cout << " ("
90          << (float(tv_end.tv_sec - tv_start.tv_sec) * 1000 + float(tv_end.tv_usec - tv_start.tv_usec)/1000)
91          << "ms).\n";
92     gettimeofday(&tv_start, 0);
93     cout << "Learning with " << classifier_type;
94     cout.flush();
95   }
96
97   cout.flush();
98
99   DataSet reduced_training_set(training_set, *selector);
100
101   if(strcmp(classifier_type, "bayesian") == 0) {
102     LinearClassifier *tmp = new LinearClassifier(nb_selected_features);
103     tmp->learn_bayesian(reduced_training_set, balanced_error);
104     classifier = tmp;
105   }
106
107   else if(strcmp(classifier_type, "perceptron") == 0) {
108     LinearClassifier *tmp = new LinearClassifier(nb_selected_features);
109     tmp->learn_perceptron(reduced_training_set, balanced_error);
110     classifier = tmp;
111   }
112
113   else {
114     cerr << "Unknown learning method type " << classifier_type << "\n";
115     exit(1);
116   }
117
118   if(verbose) {
119     gettimeofday(&tv_end, 0);
120     cout << " ("
121          << (float(tv_end.tv_sec - tv_start.tv_sec) * 1000 + float(tv_end.tv_usec - tv_start.tv_usec)/1000)
122          << "ms).\n";
123   }
124
125   cout.flush();
126 }
127
128 int main(int argc, char **argv) {
129   bool arg_error = false;
130
131   int i = 1;
132   while(i < argc && !arg_error) {
133
134     //////////////////////////////////////////////////////////////////////
135     // Parameters ////////////////////////////////////////////////////////
136     //////////////////////////////////////////////////////////////////////
137
138     if(strcmp(argv[i], "--silent") == 0) {
139       verbose = false;
140       i++;
141     }
142
143     else if(strcmp(argv[i], "--feature-selection") == 0) {
144       check_opt(argc, argv, i, 1, "<random|mim|cmim>");
145       strncpy(feature_selection_type, argv[i+1], BUFFER_SIZE);
146       i += 2;
147     }
148
149     else if(strcmp(argv[i], "--classifier") == 0) {
150       check_opt(argc, argv, i, 1, "<bayesian|perceptron>");
151       strncpy(classifier_type, argv[i+1], BUFFER_SIZE);
152       i += 2;
153     }
154
155     else if(strcmp(argv[i], "--error") == 0) {
156       check_opt(argc, argv, i, 1, "<standard|ber>");
157       if(strcmp(argv[i+1], "standard") == 0) balanced_error = false;
158       else if(strcmp(argv[i+1], "ber") == 0) balanced_error = true;
159       else {
160         cerr << "Unknown  error type " << argv[i+1] << "!\n";
161         exit(1);
162       }
163       i += 2;
164     }
165
166     else if(strcmp(argv[i], "--nb-features") == 0) {
167       check_opt(argc, argv, i, 1, "<int: nb features>");
168       nb_selected_features = atoi(argv[i+1]);
169       if(nb_selected_features <= 0) {
170         cerr << "Unconsistent number of selected features (" << nb_selected_features << ").\n";
171         exit(1);
172       }
173       i += 2;
174     }
175
176     //////////////////////////////////////////////////////////////////////
177     // Training //////////////////////////////////////////////////////////
178     //////////////////////////////////////////////////////////////////////
179
180     else if(strcmp(argv[i], "--cross-validation") == 0) {
181       check_opt(argc, argv, i, 3, "<file: data set> <int: nb test samples> <int: nb loops>");
182       if(verbose) {
183         cout << "Loading data.\n";
184         cout.flush();
185       }
186
187       ifstream complete_data(argv[i+1]);
188       if(complete_data.fail()) {
189         cerr << "Can not open " << argv[i+1] << " for reading!\n";
190         exit(1);
191       }
192
193       int nb_for_test = atoi(argv[i+2]);
194       if(nb_for_test <= 0) {
195         cerr << "Unconsistent number of samples for test (" << nb_selected_features << ").\n";
196         exit(1);
197       }
198
199       int nb_cv_loops = atoi(argv[i+3]);
200       if(nb_cv_loops <= 0) {
201         cerr << "Unconsistent number of cross-validation loops (" << nb_cv_loops << ").\n";
202         exit(1);
203       }
204
205       DataSet complete_set(complete_data);
206
207       int n00_test = 0, n01_test = 0, n10_test = 0, n11_test = 0;
208       int n00_train = 0, n01_train = 0, n10_train = 0, n11_train = 0;
209
210       for(int ncv = 0; ncv < nb_cv_loops; ncv++) {
211         bool for_test[complete_set.nb_samples];
212
213         for(int s = 0; s < complete_set.nb_samples; s++) for_test[s] = false;
214         for(int i = 0; i < nb_for_test; i++) {
215           int s;
216           do {
217             s = int(drand48() * complete_set.nb_samples);
218           } while(for_test[s]);
219           for_test[s] = true;
220         }
221
222         DataSet testing_set(complete_set, for_test);
223         for(int s = 0; s < complete_set.nb_samples; s++) for_test[s] = !for_test[s];
224         DataSet training_set(complete_set, for_test);
225
226         train(training_set);
227
228         int n00, n01, n10, n11;
229
230         {
231           float result[training_set.nb_samples];
232           compute_error_rates(selector, classifier, training_set, n00, n01, n10, n11, result);
233           n00_train += n00; n01_train += n01; n10_train += n10; n11_train += n11;
234         }
235
236         {
237           float result[testing_set.nb_samples];
238           compute_error_rates(selector, classifier, testing_set, n00, n01, n10, n11, result);
239           n00_test += n00; n01_test += n01; n10_test += n10; n11_test += n11;
240         }
241
242         delete classifier;
243         delete selector;
244       }
245
246       if(balanced_error) {
247         cout << "BER [" << nb_cv_loops << " loops] "
248                   << " train " << 0.5 * (float(n01_train)/float(n00_train + n01_train) + float(n10_train)/float(n10_train + n11_train))
249                   << " test " << 0.5 * (float(n01_test)/float(n00_test + n01_test) + float(n10_test)/float(n10_test + n11_test)) << "\n";
250       } else {
251         cout << "Error [" << nb_cv_loops << " loops] "
252                   << " train " << float(n01_train + n10_train)/float(n00_train + n01_train + n10_train + n11_train)
253                   << " test " << float(n01_test + n10_test)/float(n00_test + n01_test + n10_test + n11_test) << "\n";
254       }
255
256       i += 4;
257     }
258
259     //////////////////////////////////////////////////////////////////////
260
261     else if(strcmp(argv[i], "--train") == 0) {
262       check_opt(argc, argv, i, 2, "<file: data set> <file: classifier>");
263
264       if(verbose) {
265         cout << "Loading data.\n";
266         cout.flush();
267       }
268
269       ifstream training_data(argv[i+1]);
270       if(training_data.fail()) {
271         cerr << "Can not open " << argv[i+1] << " for reading!\n";
272         exit(1);
273       }
274
275       DataSet training_set(training_data);
276
277       //////////////////////////////////////////////////////////////////////
278       // Learning with CMIM + naive Bayesian ///////////////////////////////
279       //////////////////////////////////////////////////////////////////////
280
281       train(training_set);
282
283       //////////////////////////////////////////////////////////////////////
284       // Finishing and saving //////////////////////////////////////////////
285       //////////////////////////////////////////////////////////////////////
286
287       if(verbose) cout << "Saving the classifier in [" << argv[i+2] << "].\n";
288       ofstream classifier_out(argv[i+2]);
289       if(classifier_out.fail()) {
290         cerr << "Can not open " << argv[i+2] << " for writing!\n";
291         exit(1);
292       }
293
294       selector->save(classifier_out);
295       classifier->save(classifier_out);
296
297       delete classifier;
298       delete selector;
299
300       i += 3;
301     }
302
303     //////////////////////////////////////////////////////////////////////
304     // Test //////////////////////////////////////////////////////////////
305     //////////////////////////////////////////////////////////////////////
306
307     else if(strcmp(argv[i], "--test") == 0) {
308       check_opt(argc, argv, i, 3, "<file: classifier> <file: data set> <file: result>");
309
310       // Load the classifier
311
312       if(verbose) cout << "Loading the classifier from [" << argv[i+1] << "].\n";
313
314       ifstream classifier_in(argv[i+1]);
315       if(classifier_in.fail()) {
316         cerr << "Can not open " << argv[i+1] << " for reading!\n";
317         exit(1);
318       }
319
320       selector = new FeatureSelector(classifier_in);
321       classifier = Classifier::load(classifier_in);
322
323       // Load the testing data
324
325       ifstream testing_data(argv[i+2]);
326       if(testing_data.fail()) {
327         cerr << "Can not open " << argv[i+2] << " for reading!\n";
328         exit(1);
329       }
330
331       ofstream result_out(argv[i+3]);
332       if(result_out.fail()) {
333         cerr << "Can not open " << argv[i+3] << " for writing!\n";
334         exit(1);
335       }
336
337       DataSet testing_set(testing_data);
338
339       // Compute the predicted responses
340
341       int n00, n01, n10, n11;
342       float result[testing_set.nb_samples];
343       compute_error_rates(selector, classifier, testing_set, n00, n01, n10, n11, result);
344
345       for(int s = 0; s < testing_set.nb_samples; s++)
346         result_out << result[s] << "\n";
347
348       cout << "ERROR " << float(n01 + n10)/float(n00 + n01 + n10 + n11) << "\n";
349       cout << "BER   " << 0.5 * (float(n01)/float(n00 + n01) + float(n10)/float(n10 + n11)) << "\n";
350       cout << "FN    " << float(n10)/float(n10+n11) << "\n";
351       cout << "FP    " << float(n01)/float(n01+n00) << "\n";
352       cout << "real_0_predicted_0 " << n00 << "\n";
353       cout << "real_0_predicted_1 " << n01 << "\n";
354       cout << "real_1_predicted_0 " << n10 << "\n";
355       cout << "real_1_predicted_1 " << n11 << "\n";
356
357       delete classifier;
358       delete selector;
359
360       i += 4;
361     }
362
363     else arg_error = true;
364   }
365
366   if(arg_error) {
367     cerr << "Conditional Mutual Information Maximization\n";
368     cerr << "Written by François Fleuret (c) EPFL 2004\n";
369     cerr << "Comments and bug reports to <francois.fleuret@epfl.ch>\n";
370     cerr << "\n";
371     cerr << "Usage: " << argv[0] << "\n";
372     cerr << "--silent\n";
373     cerr << "--feature-selection <random|mim|cmim>\n";
374     cerr << "--classifier <bayesian|perceptron>\n";
375     cerr << "--error <standard|ber>\n";
376     cerr << "--nb-features <int: nb of features>\n";
377     cerr << "--cross-validation <file: data set> <int: nb test samples> <int: nb loops>\n";
378     cerr << "--train <file: data set> <file: classifier>\n";
379     cerr << "--test <file: classifier> <file: data set> <file: result>\n";
380     exit(1);
381   }
382 }