automatic commit
[folded-ctf.git] / folding.cc
1
2 ///////////////////////////////////////////////////////////////////////////
3 // This program is free software: you can redistribute it and/or modify  //
4 // it under the terms of the version 3 of the GNU General Public License //
5 // as published by the Free Software Foundation.                         //
6 //                                                                       //
7 // This program is distributed in the hope that it will be useful, but   //
8 // WITHOUT ANY WARRANTY; without even the implied warranty of            //
9 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      //
10 // General Public License for more details.                              //
11 //                                                                       //
12 // You should have received a copy of the GNU General Public License     //
13 // along with this program. If not, see <http://www.gnu.org/licenses/>.  //
14 //                                                                       //
15 // Written by Francois Fleuret                                           //
16 // (C) Idiap Research Institute                                          //
17 //                                                                       //
18 // Contact <francois.fleuret@idiap.ch> for comments & bug reports        //
19 ///////////////////////////////////////////////////////////////////////////
20
21 #include <iostream>
22 #include <fstream>
23 #include <cmath>
24 #include <stdio.h>
25 #include <stdlib.h>
26 #include <string.h>
27
28 using namespace std;
29
30 #include "misc.h"
31 #include "param_parser.h"
32 #include "global.h"
33 #include "labelled_image_pool_file.h"
34 #include "labelled_image_pool_subset.h"
35 #include "tools.h"
36 #include "detector.h"
37 #include "pose_cell_hierarchy.h"
38 #include "error_rates.h"
39 #include "materials.h"
40
41 //////////////////////////////////////////////////////////////////////
42
// Guard used by the action handlers in main(): returns normally when
// `condition` holds; otherwise prints `message` on the standard error
// stream and terminates the program with exit status 1.
void check(bool condition, const char *message) {
  if(condition) return;

  std::cerr << message << std::endl;
  exit(1);
}
49
50 //////////////////////////////////////////////////////////////////////
51
52 int main(int argc, char **argv) {
53   char *new_argv[argc];
54   int new_argc = 0;
55
56 #ifdef DEBUG
57   cout << endl;
58   cout << "**********************************************************************" << endl;
59   cout << "**                     COMPILED IN DEBUG MODE                       **" << endl;
60   cout << "**********************************************************************" << endl;
61   cout << endl;
62 #endif
63
64   cout << "-- ARGUMENTS ---------------------------------------------------------" << endl;
65   for(int i = 0; i < argc; i++)
66     cout << (i > 0 ? "  " : "") << argv[i] << (i < argc - 1 ? " \\" : "")
67          << endl;
68
69   {
70     ParamParser parser;
71     global.init_parser(&parser);
72     parser.parse_options(argc, argv, false, &new_argc, new_argv);
73     global.read_parser(&parser);
74     (*global.log_stream)
75       << "-- PARAMETERS --------------------------------------------------------"
76       << endl;
77     parser.print_all(global.log_stream);
78   }
79
80   nice(global.niceness);
81
82   (*global.log_stream) << "INFO RANDOM_SEED " << global.random_seed << endl;
83   srand48(global.random_seed);
84
85   LabelledImagePool *main_pool = 0;
86   LabelledImagePool *train_pool = 0, *validation_pool = 0, *hierarchy_pool = 0;
87   LabelledImagePool *test_pool = 0;
88   Detector *detector = 0;
89
90   {
91     char buffer[buffer_size];
92     gethostname(buffer, buffer_size);
93     (*global.log_stream) << "INFO HOSTNAME " << buffer << endl;
94   }
95
96   for(int c = 1; c < new_argc; c++) {
97
98     if(strcmp(new_argv[c], "open-pool") == 0) {
99       cout
100         << "-- OPENING POOL ------------------------------------------------------"
101         << endl;
102
103       check(!main_pool, "Pool already opened.");
104       check(global.pool_name[0], "No pool file.");
105
106       main_pool = new LabelledImagePoolFile(global.pool_name);
107
108       bool for_test[main_pool->nb_images()];
109       bool for_train[main_pool->nb_images()];
110       bool for_validation[main_pool->nb_images()];
111       bool for_hierarchy[main_pool->nb_images()];
112
113       for(int n = 0; n < main_pool->nb_images(); n++) {
114         for_test[n] = false;
115         for_train[n] = false;
116         for_validation[n] = false;
117         scalar_t r = drand48();
118         if(r < global.proportion_for_train)
119           for_train[n] = true;
120         else if(r < global.proportion_for_train + global.proportion_for_validation)
121           for_validation[n] = true;
122         else if(global.proportion_for_test < 0 ||
123                 r < global.proportion_for_train +
124                 global.proportion_for_validation +
125                 global.proportion_for_test)
126           for_test[n] = true;
127         for_hierarchy[n] = for_train[n] || for_validation[n];
128       }
129
130       train_pool = new LabelledImagePoolSubset(main_pool, for_train);
131       validation_pool = new LabelledImagePoolSubset(main_pool, for_validation);
132       hierarchy_pool = new LabelledImagePoolSubset(main_pool, for_hierarchy);
133
134       if(global.test_pool_name[0]) {
135         test_pool = new LabelledImagePoolFile(global.test_pool_name);
136       } else {
137         test_pool = new LabelledImagePoolSubset(main_pool, for_test);
138       }
139
140       cout << "Using "
141            << train_pool->nb_images() << " images for train, "
142            << validation_pool->nb_images() << " images for validation, "
143            << hierarchy_pool->nb_images() << " images for the hierarchy and "
144            << test_pool->nb_images() << " images for test."
145            << endl;
146
147     }
148
149     //////////////////////////////////////////////////////////////////////
150
151     else if(strcmp(new_argv[c], "train-detector") == 0) {
152       cout << "-- TRAIN DETECTOR ----------------------------------------------------" << endl;
153       check(train_pool, "No train pool available.");
154       check(validation_pool, "No validation pool available.");
155       check(hierarchy_pool, "No hierarchy pool available.");
156       check(!detector, "Existing detector, can not train another one.");
157       detector = new Detector();
158       detector->train(train_pool, validation_pool, hierarchy_pool);
159     }
160
161     else if(strcmp(new_argv[c], "compute-thresholds") == 0) {
162       cout << "-- COMPUTE THRESHOLDS ------------------------------------------------" << endl;
163       check(validation_pool, "No validation pool available.");
164       check(detector, "No detector.");
165       detector->compute_thresholds(validation_pool, global.wanted_true_positive_rate);
166     }
167
168     //////////////////////////////////////////////////////////////////////
169
170     else if(strcmp(new_argv[c], "test-detector") == 0) {
171       cout << "-- TEST DETECTOR -----------------------------------------------------" << endl;
172
173       check(test_pool, "No test pool available.");
174       check(detector, "No detector.");
175
176       if(test_pool->nb_images() > 0) {
177         print_decimated_error_rate(global.nb_levels - 1, test_pool, detector);
178       } else {
179         cout << "No test image." << endl;
180       }
181     }
182
183     //////////////////////////////////////////////////////////////////////
184
185     else if(strcmp(new_argv[c], "sequence-test-detector") == 0) {
186       cout << "-- SEQUENCE TEST DETECTOR --------------------------------------------" << endl;
187
188       check(test_pool, "No test pool available.");
189       check(detector, "No detector.");
190
191       if(test_pool->nb_images() > 0) {
192
193         for(int n = 0; n < global.nb_wanted_true_positive_rates; n++) {
194           scalar_t r = global.wanted_true_positive_rate *
195             scalar_t(n + 1) / scalar_t(global.nb_wanted_true_positive_rates);
196           cout << "Testint at tp " << r
197                << " (" << n + 1 << "/" << global.nb_wanted_true_positive_rates << ")"
198                << endl;
199           (*global.log_stream) << "INFO THRESHOLD_FOR_TP " << r << endl;
200           detector->compute_thresholds(validation_pool, r);
201           print_decimated_error_rate(global.nb_levels - 1, test_pool, detector);
202         }
203       } else {
204         cout << "No test image." << endl;
205       }
206     }
207
208     //////////////////////////////////////////////////////////////////////
209
210     else if(strcmp(new_argv[c], "write-detector") == 0) {
211       cout << "-- WRITE DETECTOR ----------------------------------------------------" << endl;
212       ofstream out(global.detector_name);
213       if(out.fail()) {
214         cerr << "Can not write to " << global.detector_name << endl;
215         exit(1);
216       }
217       check(detector, "No detector available.");
218       detector->write(&out);
219     }
220
221     //////////////////////////////////////////////////////////////////////
222
223     else if(strcmp(new_argv[c], "read-detector") == 0) {
224       cout << "-- READ DETECTOR -----------------------------------------------------" << endl;
225
226       check(!detector, "Existing detector, can not load another one.");
227
228       ifstream in(global.detector_name);
229       if(in.fail()) {
230         cerr << "Can not read from " << global.detector_name << endl;
231         exit(1);
232       }
233
234       detector = new Detector();
235       detector->read(&in);
236     }
237
238     //////////////////////////////////////////////////////////////////////
239
240     else if(strcmp(new_argv[c], "write-pool-images") == 0) {
241       cout << "-- WRITING POOL IMAGES -----------------------------------------------" << endl;
242       check(global.nb_images > 0, "You must set nb_images to a positive value.");
243       check(train_pool, "No train pool available.");
244       write_pool_images_with_poses_and_referentials(train_pool, detector);
245     }
246
247     //////////////////////////////////////////////////////////////////////
248
249     else {
250       cerr << "Unknown action " << new_argv[c] << endl;
251       exit(1);
252     }
253
254     //////////////////////////////////////////////////////////////////////
255
256   }
257
258   delete detector;
259
260   delete train_pool;
261   delete validation_pool;
262   delete hierarchy_pool;
263   delete test_pool;
264
265   delete main_pool;
266
267   cout << "-- FINISHED ----------------------------------------------------------" << endl;
268
269 }